Skip to content

Commit a4d2f30

Browse files
committed
Refactor disconnect + connect into a context manager
1 parent 7fbbf9b commit a4d2f30

1 file changed

Lines changed: 16 additions & 9 deletions

File tree

rest_framework/test.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
22
# to make it harder for the user to import the wrong thing without realizing.
33
import io
4+
from contextlib import contextmanager
45
from importlib import import_module
56

67
from django.conf import settings
@@ -24,6 +25,20 @@ def force_authenticate(request, user=None, token=None):
2425
request._force_auth_token = token
2526

2627

28+
@contextmanager
29+
def _keep_connections_open():
30+
"""Prevent Django from closing the database connection while a request
31+
is dispatched, matching the behavior of Django's ClientHandler.
32+
"""
33+
request_started.disconnect(close_old_connections)
34+
request_finished.disconnect(close_old_connections)
35+
try:
36+
yield
37+
finally:
38+
request_started.connect(close_old_connections)
39+
request_finished.connect(close_old_connections)
40+
41+
2742
if requests is not None:
2843
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
2944
def get_all(self, key, default):
@@ -91,17 +106,9 @@ def start_response(wsgi_status, wsgi_headers, exc_info=None):
91106
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
92107

93108
# Make the outgoing request via WSGI.
94-
# Disconnect close_old_connections to prevent closing the
95-
# database connection during tests, matching the behavior
96-
# of Django's ClientHandler.
97109
environ = self.get_environ(request)
98-
request_started.disconnect(close_old_connections)
99-
request_finished.disconnect(close_old_connections)
100-
try:
110+
with _keep_connections_open():
101111
wsgi_response = self.app(environ, start_response)
102-
finally:
103-
request_started.connect(close_old_connections)
104-
request_finished.connect(close_old_connections)
105112

106113
# Build the underlying urllib3.HTTPResponse
107114
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))

0 commit comments

Comments
 (0)