diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index 90f0b40356..cbd4cd3894 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -182,6 +182,7 @@ The available decorators are: * `@parser_classes(...)` * `@authentication_classes(...)` * `@throttle_classes(...)` +* `@throttle_scope(...)` * `@permission_classes(...)` * `@content_negotiation_class(...)` * `@metadata_class(...)` diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index a69a613ba4..c34778991c 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -67,6 +67,9 @@ def handler(self, *args, **kwargs): WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', APIView.throttle_classes) + WrappedAPIView.throttle_scope = getattr(func, 'throttle_scope', + None) + WrappedAPIView.permission_classes = getattr(func, 'permission_classes', APIView.permission_classes) @@ -136,6 +139,14 @@ def decorator(func): return decorator +def throttle_scope(throttle_scope): + def decorator(func): + _check_decorator_order(func, 'throttle_scope') + func.throttle_scope = throttle_scope + return func + return decorator + + def permission_classes(permission_classes): def decorator(func): _check_decorator_order(func, 'permission_classes') diff --git a/tests/test_decorators.py b/tests/test_decorators.py index cc7cab4d7d..bc6f3a23e4 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -8,7 +8,7 @@ from rest_framework.decorators import ( action, api_view, authentication_classes, content_negotiation_class, metadata_class, parser_classes, permission_classes, renderer_classes, - schema, throttle_classes, versioning_class + schema, throttle_classes, throttle_scope, versioning_class ) from rest_framework.negotiation import BaseContentNegotiation from rest_framework.parsers import JSONParser @@ -17,7 +17,7 @@ from rest_framework.response import Response from rest_framework.schemas import AutoSchema from rest_framework.test import APIRequestFactory -from rest_framework.throttling import UserRateThrottle +from rest_framework.throttling import ScopedRateThrottle, UserRateThrottle from rest_framework.versioning import QueryParameterVersioning from rest_framework.views import APIView @@ -153,6 +153,31 @@ def view(request): response = view(request) assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + def test_throttle_scope(self): + scope = "x" + + class OncePerDayScopedThrottle(ScopedRateThrottle): + THROTTLE_RATES = {scope: "1/day"} + + @api_view(['GET']) + @throttle_classes([OncePerDayScopedThrottle]) + @throttle_scope(scope) + def view_1(request): + return Response({}) + + @api_view(['GET']) + @throttle_classes([OncePerDayScopedThrottle]) + @throttle_scope(scope) + def view_2(request): + return Response({}) + + request = self.factory.get('/') + response = view_1(request) + assert response.status_code == status.HTTP_200_OK + + response = view_2(request) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + def test_versioning_class(self): @api_view(["GET"]) @versioning_class(QueryParameterVersioning)