1+ from typing import Any , Dict , Optional
2+
3+ class RateLimitAssertionError (AssertionError ):
4+ pass
5+
6+ def assert_allowed (
7+ limiter : Any ,
8+ key : str ,
9+ weight : int = 1 ,
10+ times : int = 1 ,
11+ message : Optional [str ] = None
12+ ):
13+ for i in range (times ):
14+ allowed , state = limiter .check (key , weight )
15+ if not allowed :
16+ error_msg = message or (
17+ f"Expected key '{ key } ' to be allowed on attempt { i + 1 } /{ times } , "
18+ f"but it was denied. State: { state } "
19+ )
20+ raise RateLimitAssertionError (error_msg )
21+
22+ def assert_denied (
23+ limiter : Any ,
24+ key : str ,
25+ weight : int = 1 ,
26+ message : Optional [str ] = None
27+ ):
28+ allowed , state = limiter .check (key , weight )
29+ if allowed :
30+ error_msg = message or (
31+ f"Expected key '{ key } ' to be denied, but it was allowed. "
32+ f"State: { state } "
33+ )
34+ raise RateLimitAssertionError (error_msg )
35+
36+ def assert_remaining (
37+ limiter : Any ,
38+ key : str ,
39+ expected : int ,
40+ tolerance : int = 0 ,
41+ message : Optional [str ] = None
42+ ):
43+ allowed , state = limiter .check (key , weight = 0 ) # Weight 0 = peek without consuming
44+ remaining = state .get ('remaining' , 0 )
45+ if tolerance > 0 :
46+ if not (expected - tolerance <= remaining <= expected + tolerance ):
47+ error_msg = message or (
48+ f"Expected key '{ key } ' to have { expected } remaining "
49+ f"(±{ tolerance } ), but got { remaining } . State: { state } "
50+ )
51+ raise RateLimitAssertionError (error_msg )
52+ else :
53+ if remaining != expected :
54+ error_msg = message or (
55+ f"Expected key '{ key } ' to have { expected } remaining, "
56+ f"but got { remaining } . State: { state } "
57+ )
58+ raise RateLimitAssertionError (error_msg )
59+
60+ def assert_state (
61+ limiter : Any ,
62+ key : str ,
63+ message : Optional [str ] = None ,
64+ ** expected_values
65+ ):
66+ allowed , state = limiter .check (key , weight = 0 ) # Peek
67+ mismatches = []
68+ for key_name , expected_value in expected_values .items ():
69+ actual_value = state .get (key_name )
70+
71+ if actual_value != expected_value :
72+ mismatches .append (
73+ f" { key_name } : expected { expected_value } , got { actual_value } "
74+ )
75+ if mismatches :
76+ error_msg = message or (
77+ f"State mismatch for key '{ key } ':\n " + "\n " .join (mismatches ) +
78+ f"\n Full state: { state } "
79+ )
80+ raise RateLimitAssertionError (error_msg )
81+
82+ def assert_allows_n_then_denies (
83+ limiter : Any ,
84+ key : str ,
85+ n : int ,
86+ weight : int = 1 ,
87+ message : Optional [str ] = None
88+ ):
89+ try :
90+ assert_allowed (limiter , key , weight = weight , times = n )
91+ except RateLimitAssertionError as e :
92+ error_msg = message or f"Expected { n } allowed requests, but failed: { e } "
93+ raise RateLimitAssertionError (error_msg )
94+ try :
95+ assert_denied (limiter , key , weight = weight )
96+ except RateLimitAssertionError as e :
97+ error_msg = message or f"Expected denial after { n } requests, but was allowed: { e } "
98+ raise RateLimitAssertionError (error_msg )
99+
100+ def assert_retry_after (
101+ limiter : Any ,
102+ key : str ,
103+ min_seconds : float = 0 ,
104+ max_seconds : Optional [float ] = None ,
105+ message : Optional [str ] = None
106+ ):
107+ allowed , state = limiter .check (key , weight = 0 )
108+ retry_after = state .get ('retry_after' , 0 )
109+ if retry_after < min_seconds :
110+ error_msg = message or (
111+ f"Expected retry_after >= { min_seconds } , but got { retry_after } "
112+ )
113+ raise RateLimitAssertionError (error_msg )
114+
115+ if max_seconds is not None and retry_after > max_seconds :
116+ error_msg = message or (
117+ f"Expected retry_after <= { max_seconds } , but got { retry_after } "
118+ )
119+ raise RateLimitAssertionError (error_msg )
120+
121+ def assert_limit_equals (
122+ limiter : Any ,
123+ key : str ,
124+ expected_limit : int ,
125+ message : Optional [str ] = None
126+ ):
127+ allowed , state = limiter .check (key , weight = 0 )
128+
129+ actual_limit = state .get ('limit' , 0 )
130+
131+ if actual_limit != expected_limit :
132+ error_msg = message or (
133+ f"Expected limit of { expected_limit } , but got { actual_limit } "
134+ )
135+ raise RateLimitAssertionError (error_msg )
136+
137+ def assert_eventually_allowed (
138+ limiter : Any ,
139+ key : str ,
140+ time_machine : Any ,
141+ max_advance : float = 300 ,
142+ step : float = 1 ,
143+ weight : int = 1 ,
144+ message : Optional [str ] = None
145+ ):
146+ total_advanced = 0
147+
148+ while total_advanced < max_advance :
149+ allowed , state = limiter .check (key , weight = weight )
150+
151+ if allowed :
152+ return
153+
154+ time_machine .advance (step )
155+ total_advanced += step
156+
157+ error_msg = message or (
158+ f"Key '{ key } ' was not allowed even after advancing { max_advance } s"
159+ )
160+ raise RateLimitAssertionError (error_msg )
161+
162+ def assert_state_contains (
163+ limiter : Any ,
164+ key : str ,
165+ * expected_keys : str ,
166+ message : Optional [str ] = None
167+ ):
168+ allowed , state = limiter .check (key , weight = 0 )
169+ missing_keys = [k for k in expected_keys if k not in state ]
170+
171+ if missing_keys :
172+ error_msg = message or (
173+ f"State for key '{ key } ' is missing keys: { missing_keys } . "
174+ f"State: { state } "
175+ )
176+ raise RateLimitAssertionError (error_msg )
0 commit comments