Skip to content

Commit d8df0f1

Browse files
committed
Refactor disconnect + connect into a context manager
1 parent 7fbbf9b commit d8df0f1

1 file changed

Lines changed: 17 additions & 9 deletions

File tree

rest_framework/test.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
22
# to make it harder for the user to import the wrong thing without realizing.
33
import io
4+
from contextlib import contextmanager
45
from importlib import import_module
56

67
from django.conf import settings
@@ -24,6 +25,21 @@ def force_authenticate(request, user=None, token=None):
2425
request._force_auth_token = token
2526

2627

28+
@contextmanager
29+
def _keep_connections_open():
30+
"""
31+
Prevent Django from closing the database connection while a request
32+
is dispatched, matching the behavior of Django's ClientHandler.
33+
"""
34+
request_started.disconnect(close_old_connections)
35+
request_finished.disconnect(close_old_connections)
36+
try:
37+
yield
38+
finally:
39+
request_started.connect(close_old_connections)
40+
request_finished.connect(close_old_connections)
41+
42+
2743
if requests is not None:
2844
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
2945
def get_all(self, key, default):
@@ -91,17 +107,9 @@ def start_response(wsgi_status, wsgi_headers, exc_info=None):
91107
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
92108

93109
# Make the outgoing request via WSGI.
94-
# Disconnect close_old_connections to prevent closing the
95-
# database connection during tests, matching the behavior
96-
# of Django's ClientHandler.
97110
environ = self.get_environ(request)
98-
request_started.disconnect(close_old_connections)
99-
request_finished.disconnect(close_old_connections)
100-
try:
111+
with _keep_connections_open():
101112
wsgi_response = self.app(environ, start_response)
102-
finally:
103-
request_started.connect(close_old_connections)
104-
request_finished.connect(close_old_connections)
105113

106114
# Build the underlying urllib3.HTTPResponse
107115
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))

0 commit comments

Comments
 (0)