Skip to content

Commit 3e6ce9b

Browse files
committed
Add rate limit testing assertions
Add a new testing helper module ratelink/testing/assertions.py providing a RateLimitAssertionError and a suite of assertion helpers for rate limiter tests (assert_allowed, assert_denied, assert_remaining, assert_state, assert_allows_n_then_denies, assert_retry_after, assert_limit_equals, assert_eventually_allowed, assert_state_contains). Helpers use limiter.check(key, weight) (weight=0 to peek) and support custom messages, tolerances, and time advancement via a provided time_machine. These utilities simplify writing deterministic tests for rate-limiting behavior and state verification.
1 parent 02f595f commit 3e6ce9b

1 file changed

Lines changed: 176 additions & 0 deletions

File tree

ratelink/testing/assertions.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from typing import Any, Dict, Optional
2+
3+
class RateLimitAssertionError(AssertionError):
4+
pass
5+
6+
def assert_allowed(
7+
limiter: Any,
8+
key: str,
9+
weight: int = 1,
10+
times: int = 1,
11+
message: Optional[str] = None
12+
):
13+
for i in range(times):
14+
allowed, state = limiter.check(key, weight)
15+
if not allowed:
16+
error_msg = message or (
17+
f"Expected key '{key}' to be allowed on attempt {i+1}/{times}, "
18+
f"but it was denied. State: {state}"
19+
)
20+
raise RateLimitAssertionError(error_msg)
21+
22+
def assert_denied(
23+
limiter: Any,
24+
key: str,
25+
weight: int = 1,
26+
message: Optional[str] = None
27+
):
28+
allowed, state = limiter.check(key, weight)
29+
if allowed:
30+
error_msg = message or (
31+
f"Expected key '{key}' to be denied, but it was allowed. "
32+
f"State: {state}"
33+
)
34+
raise RateLimitAssertionError(error_msg)
35+
36+
def assert_remaining(
37+
limiter: Any,
38+
key: str,
39+
expected: int,
40+
tolerance: int = 0,
41+
message: Optional[str] = None
42+
):
43+
allowed, state = limiter.check(key, weight=0) # Weight 0 = peek without consuming
44+
remaining = state.get('remaining', 0)
45+
if tolerance > 0:
46+
if not (expected - tolerance <= remaining <= expected + tolerance):
47+
error_msg = message or (
48+
f"Expected key '{key}' to have {expected} remaining "
49+
f"(±{tolerance}), but got {remaining}. State: {state}"
50+
)
51+
raise RateLimitAssertionError(error_msg)
52+
else:
53+
if remaining != expected:
54+
error_msg = message or (
55+
f"Expected key '{key}' to have {expected} remaining, "
56+
f"but got {remaining}. State: {state}"
57+
)
58+
raise RateLimitAssertionError(error_msg)
59+
60+
def assert_state(
61+
limiter: Any,
62+
key: str,
63+
message: Optional[str] = None,
64+
**expected_values
65+
):
66+
allowed, state = limiter.check(key, weight=0) # Peek
67+
mismatches = []
68+
for key_name, expected_value in expected_values.items():
69+
actual_value = state.get(key_name)
70+
71+
if actual_value != expected_value:
72+
mismatches.append(
73+
f" {key_name}: expected {expected_value}, got {actual_value}"
74+
)
75+
if mismatches:
76+
error_msg = message or (
77+
f"State mismatch for key '{key}':\n" + "\n".join(mismatches) +
78+
f"\nFull state: {state}"
79+
)
80+
raise RateLimitAssertionError(error_msg)
81+
82+
def assert_allows_n_then_denies(
83+
limiter: Any,
84+
key: str,
85+
n: int,
86+
weight: int = 1,
87+
message: Optional[str] = None
88+
):
89+
try:
90+
assert_allowed(limiter, key, weight=weight, times=n)
91+
except RateLimitAssertionError as e:
92+
error_msg = message or f"Expected {n} allowed requests, but failed: {e}"
93+
raise RateLimitAssertionError(error_msg)
94+
try:
95+
assert_denied(limiter, key, weight=weight)
96+
except RateLimitAssertionError as e:
97+
error_msg = message or f"Expected denial after {n} requests, but was allowed: {e}"
98+
raise RateLimitAssertionError(error_msg)
99+
100+
def assert_retry_after(
101+
limiter: Any,
102+
key: str,
103+
min_seconds: float = 0,
104+
max_seconds: Optional[float] = None,
105+
message: Optional[str] = None
106+
):
107+
allowed, state = limiter.check(key, weight=0)
108+
retry_after = state.get('retry_after', 0)
109+
if retry_after < min_seconds:
110+
error_msg = message or (
111+
f"Expected retry_after >= {min_seconds}, but got {retry_after}"
112+
)
113+
raise RateLimitAssertionError(error_msg)
114+
115+
if max_seconds is not None and retry_after > max_seconds:
116+
error_msg = message or (
117+
f"Expected retry_after <= {max_seconds}, but got {retry_after}"
118+
)
119+
raise RateLimitAssertionError(error_msg)
120+
121+
def assert_limit_equals(
122+
limiter: Any,
123+
key: str,
124+
expected_limit: int,
125+
message: Optional[str] = None
126+
):
127+
allowed, state = limiter.check(key, weight=0)
128+
129+
actual_limit = state.get('limit', 0)
130+
131+
if actual_limit != expected_limit:
132+
error_msg = message or (
133+
f"Expected limit of {expected_limit}, but got {actual_limit}"
134+
)
135+
raise RateLimitAssertionError(error_msg)
136+
137+
def assert_eventually_allowed(
138+
limiter: Any,
139+
key: str,
140+
time_machine: Any,
141+
max_advance: float = 300,
142+
step: float = 1,
143+
weight: int = 1,
144+
message: Optional[str] = None
145+
):
146+
total_advanced = 0
147+
148+
while total_advanced < max_advance:
149+
allowed, state = limiter.check(key, weight=weight)
150+
151+
if allowed:
152+
return
153+
154+
time_machine.advance(step)
155+
total_advanced += step
156+
157+
error_msg = message or (
158+
f"Key '{key}' was not allowed even after advancing {max_advance}s"
159+
)
160+
raise RateLimitAssertionError(error_msg)
161+
162+
def assert_state_contains(
163+
limiter: Any,
164+
key: str,
165+
*expected_keys: str,
166+
message: Optional[str] = None
167+
):
168+
allowed, state = limiter.check(key, weight=0)
169+
missing_keys = [k for k in expected_keys if k not in state]
170+
171+
if missing_keys:
172+
error_msg = message or (
173+
f"State for key '{key}' is missing keys: {missing_keys}. "
174+
f"State: {state}"
175+
)
176+
raise RateLimitAssertionError(error_msg)

0 commit comments

Comments
 (0)