Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5482,6 +5482,7 @@ class ResponseFuture(object):
_errbacks = None
_current_host = None
_connection = None
_connection_pool = None
_query_retries = 0
_start_time = None
_metrics = None
Expand Down Expand Up @@ -5578,7 +5579,7 @@ def _on_timeout(self, _attempts=0):
# Capture connection stats before pool.return_connection() can alter state
conn_in_flight = self._connection.in_flight

pool = self.session._get_pool_by_host_identity(self._current_host)
pool = self._connection_pool
if pool and not pool.is_shutdown:
# Do not return the stream ID to the pool yet. We cannot reuse it
# because the node might still be processing the query and will
Expand Down Expand Up @@ -5661,7 +5662,24 @@ def _query(self, host, message=None, cb=None):
if message is None:
message = self.message

pool = self.session._get_pool_by_host_identity(host)
expected_endpoint = None
if isinstance(host, Host):
with host.lock:
expected_endpoint = host.endpoint
Comment on lines +5666 to +5668
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to be under lock for such simple assignment?

pool = self.session._get_pool_by_host_identity(
host, expected_endpoint=expected_endpoint)
else:
pool = self.session._get_pool_by_host_identity(host)

if pool and expected_endpoint is not None:
with host.lock:
endpoint_changed = not self.session._endpoints_match(
host.endpoint, expected_endpoint)
if endpoint_changed:
self._errors[host] = ConnectionException(
"Host endpoint changed while borrowing connection")
return None

Comment on lines +5665 to +5682
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Waaait, this is PR from one of your branches to another one of your branches. Why?
I wanted to locally check what is _get_pool_by_host_identity because I don't understand the purpose of this change, but there is not such function.

I don't think you need PRs to make changes to your own work in progress. I think I should revisit this PR when it targets master.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you mention this in cover letter, sorry.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think there is much point in reviewing those changes before the base changes are understood and merged (or at least approved).

if not pool:
self._errors[host] = ConnectionException("Host has been marked down or removed")
return None
Expand All @@ -5678,7 +5696,23 @@ def _query(self, host, message=None, cb=None):
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table)
else:
connection, request_id = pool.borrow_connection(timeout=2.0)

if expected_endpoint is not None:
with host.lock:
endpoint_changed = not self.session._endpoints_match(
host.endpoint, expected_endpoint)
if endpoint_changed:
try:
pool.return_connection(connection)
finally:
connection = None
self._errors[host] = ConnectionException(
"Host endpoint changed while borrowing connection")
return None

self._connection = connection
self._connection_pool = pool

result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []

if cb is None:
Expand Down
83 changes: 54 additions & 29 deletions tests/unit/test_response_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@
# limitations under the License.

import unittest
import uuid

from collections import deque
from threading import RLock
from unittest.mock import Mock, MagicMock, ANY

from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut
from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion
from cassandra.connection import Connection, ConnectionException
from cassandra.connection import Connection, ConnectionException, DefaultEndPoint
from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage,
UnavailableErrorMessage, ResultMessage, QueryMessage,
OverloadedErrorMessage, IsBootstrappingErrorMessage,
PreparedQueryNotFound, PrepareMessage, ServerError,
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED,
ProtocolHandler)
from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy
from cassandra.pool import NoConnectionsAvailable
from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy, SimpleConvictionPolicy
from cassandra.pool import Host, NoConnectionsAvailable
from cassandra.query import SimpleStatement
from tests.util import assertEqual, assertIsInstance
import pytest
Expand All @@ -52,7 +53,7 @@ def make_pool(self):
def make_session(self):
session = self.make_basic_session()
session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']
session._pools.get.return_value = self.make_pool()
session._get_pool_by_host_identity.return_value = self.make_pool()
return session

def make_response_future(self, session):
Expand All @@ -66,7 +67,7 @@ def make_mock_response(self, col_names, rows):
def test_result_message(self):
session = self.make_basic_session()
session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
pool.is_shutdown = False

connection = Mock(spec=Connection)
Expand All @@ -75,7 +76,7 @@ def test_result_message(self):
rf = self.make_response_future(session)
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
rf.session._get_pool_by_host_identity.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)

connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
Expand All @@ -87,7 +88,7 @@ def test_result_message(self):

def test_unknown_result_class(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

Expand Down Expand Up @@ -151,7 +152,7 @@ def test_heartbeat_defunct_deadlock(self):

session = self.make_basic_session()
session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()]
session._pools.get.return_value = pool
session._get_pool_by_host_identity.return_value = pool

query = SimpleStatement("SELECT * FROM foo")
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
Expand Down Expand Up @@ -252,7 +253,7 @@ def test_retry_policy_says_ignore(self):

def test_retry_policy_says_retry(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value

query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)")
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM)
Expand All @@ -266,7 +267,7 @@ def test_retry_policy_says_retry(self):
rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy)
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
rf.session._get_pool_by_host_identity.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

Expand All @@ -285,13 +286,13 @@ def test_retry_policy_says_retry(self):

# it should try again with the same host since this was
# an UnavailableException
rf.session._pools.get.assert_called_with(host)
rf.session._get_pool_by_host_identity.assert_called_with(host)
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

def test_retry_with_different_host(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value

connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
Expand All @@ -300,7 +301,7 @@ def test_retry_with_different_host(self):
rf.message.consistency_level = ConsistencyLevel.QUORUM
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
rf.session._get_pool_by_host_identity.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
assert ConsistencyLevel.QUORUM == rf.message.consistency_level
Expand All @@ -319,7 +320,7 @@ def test_retry_with_different_host(self):
rf._retry_task(False, host)

# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
rf.session._get_pool_by_host_identity.assert_called_with('ip2')
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

Expand All @@ -328,13 +329,13 @@ def test_retry_with_different_host(self):

def test_all_retries_fail(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

rf = self.make_response_future(session)
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
rf.session._get_pool_by_host_identity.assert_called_once_with('ip1')

result = Mock(spec=IsBootstrappingErrorMessage, info={})
host = Mock()
Expand All @@ -346,7 +347,7 @@ def test_all_retries_fail(self):
rf._retry_task(False, host)

# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
rf.session._get_pool_by_host_identity.assert_called_with('ip2')

result = Mock(spec=IsBootstrappingErrorMessage, info={})
rf._set_result(host, None, None, result)
Expand All @@ -360,15 +361,15 @@ def test_all_retries_fail(self):

def test_exponential_retry_policy_fail(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

query = SimpleStatement("SELECT * FROM foo")
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
rf = ResponseFuture(session, message, query, 1, retry_policy=ExponentialBackoffRetryPolicy(2))
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
rf.session._get_pool_by_host_identity.assert_called_once_with('ip1')

result = Mock(spec=IsBootstrappingErrorMessage, info={})
host = Mock()
Expand All @@ -384,7 +385,7 @@ def test_exponential_retry_policy_fail(self):
def test_all_pools_shutdown(self):
session = self.make_basic_session()
session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']
session._pools.get.return_value.is_shutdown = True
session._get_pool_by_host_identity.return_value.is_shutdown = True

rf = ResponseFuture(session, Mock(), Mock(), 1)
rf.send_request()
Expand All @@ -399,7 +400,7 @@ def test_first_pool_shutdown(self):
pool_shutdown.is_shutdown = True
pool_ok = self.make_pool()
pool_ok.is_shutdown = True
session._pools.get.side_effect = [pool_shutdown, pool_ok]
session._get_pool_by_host_identity.side_effect = [pool_shutdown, pool_ok]

rf = self.make_response_future(session)
rf.send_request()
Expand All @@ -424,7 +425,7 @@ def test_timeout_getting_connection_from_pool(self):
connection = Mock(spec=Connection)
second_pool.borrow_connection.return_value = (connection, 1)

session._pools.get.side_effect = [first_pool, second_pool]
session._get_pool_by_host_identity.side_effect = [first_pool, second_pool]

rf = self.make_response_future(session)
rf.send_request()
Expand Down Expand Up @@ -459,7 +460,7 @@ def test_callback(self):

def test_errback(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

Expand Down Expand Up @@ -508,7 +509,7 @@ def test_multiple_callbacks(self):

def test_multiple_errbacks(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

Expand Down Expand Up @@ -581,7 +582,7 @@ def test_add_callbacks(self):

def test_prepared_query_not_found(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

Expand All @@ -606,7 +607,7 @@ def test_prepared_query_not_found(self):

def test_prepared_query_not_found_bad_keyspace(self):
session = self.make_session()
pool = session._pools.get.return_value
pool = session._get_pool_by_host_identity.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)

Expand Down Expand Up @@ -655,7 +656,7 @@ def test_timeout_does_not_release_stream_id(self):
session = self.make_basic_session()
session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')]
pool = self.make_pool()
session._pools.get.return_value = pool
session._get_pool_by_host_identity.return_value = pool
connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(),
orphaned_request_ids=set(), orphaned_threshold=256, in_flight=3)
pool.borrow_connection.return_value = (connection, 1)
Expand All @@ -675,6 +676,30 @@ def test_timeout_does_not_release_stream_id(self):
assert len(connection.request_ids) == 0, \
"Request IDs should be empty but it's not: {}".format(connection.request_ids)

def test_timeout_returns_orphan_to_original_pool_after_endpoint_swap(self):
session = self.make_basic_session()
host = Host(DefaultEndPoint('127.0.0.1'), SimpleConvictionPolicy,
host_id=uuid.uuid4())
session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host]
old_pool = self.make_pool()
replacement_pool = self.make_pool()
connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(),
orphaned_request_ids=set(), orphaned_threshold=256, in_flight=1)
old_pool.borrow_connection.return_value = (connection, 1)
session._get_pool_by_host_identity.side_effect = [old_pool, replacement_pool]

rf = self.make_response_future(session)
rf.send_request()
connection._requests[1] = (connection._handle_options_response,
ProtocolHandler.decode_message, [])
host.endpoint = DefaultEndPoint('127.0.0.2')

rf._on_timeout()

replacement_pool.return_connection.assert_not_called()
old_pool.return_connection.assert_called_once_with(
connection, stream_was_orphaned=True)

def test_single_host_query_plan_exhausted_after_one_retry(self):
"""
Test that when a specific host is provided, the query plan is properly
Expand All @@ -686,7 +711,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self):
"""
session = self.make_basic_session()
pool = self.make_pool()
session._pools.get.return_value = pool
session._get_pool_by_host_identity.return_value = pool

# Create a specific host
specific_host = Mock()
Expand All @@ -702,7 +727,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self):
rf.send_request()

# Verify initial request was sent
rf.session._pools.get.assert_called_once_with(specific_host)
rf.session._get_pool_by_host_identity.assert_called_once_with(specific_host)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

Expand Down