diff --git a/tests/test_cache.py b/tests/test_cache.py index f64920a..75869df 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,5 @@ import time +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, cast import pytest @@ -398,3 +399,65 @@ def test_cache_default_ttl_sentinel(): time.sleep(0.2) assert cache.get("key2") == "value2" # Should still be there cache.close() + + +def test_cache_concurrent_get_set_delete_clear(): + """Multiple threads performing mixed operations on a shared Cache must not raise exceptions.""" + num_threads = 10 + num_operations = 200 + cache = Cache( + max_items=50, + size_limit_in_bytes=None, + default_ttl=None, + expiration_thread_max_checks_per_iteration=0, + ) + + def worker(thread_id: int) -> None: + for i in range(num_operations): + key = f"key{i % 20}" + op = (thread_id * num_operations + i) % 4 + if op == 0: + cache.set(key, f"value-{thread_id}-{i}") + elif op == 1: + result = cache.get(key) + assert result is CACHE_MISS or isinstance(result, str) + elif op == 2: + cache.delete(key) + else: + cache.clear() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, t) for t in range(num_threads)] + for future in as_completed(futures): + future.result() + + assert cache.number_of_items <= 50 + cache.close() + + +def test_cache_concurrent_with_ttl(): + """Concurrent Cache access with TTL expiration thread active must not corrupt state.""" + num_threads = 8 + num_operations = 100 + cache = Cache( + max_items=100, + default_ttl=0.05, + expiration_thread_delay=0.01, + expiration_thread_max_checks_per_iteration=50, + ) + + def worker(thread_id: int) -> None: + for i in range(num_operations): + key = f"key{i % 10}" + if i % 2 == 0: + cache.set(key, f"v{thread_id}-{i}") + else: + result = cache.get(key) + assert result is CACHE_MISS or isinstance(result, str) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, t) for t in range(num_threads)] + for future in as_completed(futures): + future.result() + + cache.close(wait=True) diff --git a/tests/test_storage.py b/tests/test_storage.py index 80b55c2..60f3d47 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,5 +1,6 @@ import random import time +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest @@ -314,6 +315,65 @@ def test_overwrite_existing_key_size_tracking(): storage.close() +def test_storage_concurrent_get_set_delete(): + """Multiple threads performing get/set/delete on a shared Storage must not raise exceptions.""" + num_threads = 10 + num_operations = 200 + storage = Storage[bytes]( + size_limit_in_bytes=None, + max_items=50, + expiration_thread_max_checks_per_iteration=0, + ) + + def worker(thread_id: int) -> None: + for i in range(num_operations): + key = f"key{i % 20}" + op = (thread_id * num_operations + i) % 3 + if op == 0: + storage.set(key, f"value-{thread_id}-{i}".encode()) + elif op == 1: + storage.get(key) + else: + storage.delete(key) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, t) for t in range(num_threads)] + for future in as_completed(futures): + future.result() + + assert storage.number_of_items <= 50 + storage.close() + + +def test_storage_concurrent_with_ttl(): + """Concurrent get/set with a live expiration thread must not corrupt state.""" + num_threads = 8 + num_operations = 100 + storage = Storage[bytes]( + size_limit_in_bytes=None, + max_items=100, + default_ttl=0.05, + expiration_thread_delay=0.01, + expiration_thread_max_checks_per_iteration=50, + ) + + def worker(thread_id: int) -> None: + for i in range(num_operations): + key = f"key{i % 10}" + if i % 2 == 0: + storage.set(key, f"v{thread_id}-{i}".encode()) + else: + result = storage.get(key) + assert result is CACHE_MISS or isinstance(result, bytes) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, t) for t in range(num_threads)] + for future in as_completed(futures): + future.result() + + storage.close(wait=True) + + def test_clear(): """Test that clear() removes all items and resets size tracking.""" storage = Storage[bytes](