11# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
22# to make it harder for the user to import the wrong thing without realizing.
33import io
4+ from contextlib import contextmanager
45from importlib import import_module
56
67from django .conf import settings
@@ -24,6 +25,20 @@ def force_authenticate(request, user=None, token=None):
2425 request ._force_auth_token = token
2526
2627
28+ @contextmanager
29+ def _keep_connections_open ():
30+ """Prevent Django from closing the database connection while a request
31+ is dispatched, matching the behavior of Django's ClientHandler.
32+ """
33+ request_started .disconnect (close_old_connections )
34+ request_finished .disconnect (close_old_connections )
35+ try :
36+ yield
37+ finally :
38+ request_started .connect (close_old_connections )
39+ request_finished .connect (close_old_connections )
40+
41+
2742if requests is not None :
2843 class HeaderDict (requests .packages .urllib3 ._collections .HTTPHeaderDict ):
2944 def get_all (self , key , default ):
@@ -91,17 +106,9 @@ def start_response(wsgi_status, wsgi_headers, exc_info=None):
91106 raw_kwargs ['original_response' ] = MockOriginalResponse (wsgi_headers )
92107
93108 # Make the outgoing request via WSGI.
94- # Disconnect close_old_connections to prevent closing the
95- # database connection during tests, matching the behavior
96- # of Django's ClientHandler.
97109 environ = self .get_environ (request )
98- request_started .disconnect (close_old_connections )
99- request_finished .disconnect (close_old_connections )
100- try :
110+ with _keep_connections_open ():
101111 wsgi_response = self .app (environ , start_response )
102- finally :
103- request_started .connect (close_old_connections )
104- request_finished .connect (close_old_connections )
105112
106113 # Build the underlying urllib3.HTTPResponse
107114 raw_kwargs ['body' ] = io .BytesIO (b'' .join (wsgi_response ))
0 commit comments