diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index ae35feff7..0c2e1c174 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -1503,7 +1503,7 @@ async def list_and_count( Returns: Count of records returned by query, ignoring pagination. """ - self.count_with_window_function = ( + count_with_window_function = ( count_with_window_function if count_with_window_function is not None else self.count_with_window_function ) self.uniquify = self._get_uniquify(uniquify) @@ -1511,7 +1511,7 @@ async def list_and_count( error_messages=error_messages, default_messages=self.error_messages, ) - if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function: + if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function: return await self._list_and_count_basic( *filters, auto_expunge=auto_expunge, diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 2ef17f648..674465723 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -1504,7 +1504,7 @@ def list_and_count( Returns: Count of records returned by query, ignoring pagination. """ - self.count_with_window_function = ( + count_with_window_function = ( count_with_window_function if count_with_window_function is not None else self.count_with_window_function ) self.uniquify = self._get_uniquify(uniquify) @@ -1512,7 +1512,7 @@ def list_and_count( error_messages=error_messages, default_messages=self.error_messages, ) - if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function: + if self._dialect.name in {"spanner", "spanner+spanner"} or not count_with_window_function: return self._list_and_count_basic( *filters, auto_expunge=auto_expunge, diff --git a/tests/unit/test_repository.py b/tests/unit/test_repository.py index 27919d485..08c162fd3 100644 --- a/tests/unit/test_repository.py +++ b/tests/unit/test_repository.py @@ -527,7 +527,6 @@ async def test_sqlalchemy_repo_list_and_count(mock_repo: SQLAlchemyAsyncReposito """Test expected method calls for list operation.""" mock_instances = [MagicMock(), MagicMock()] mock_count = len(mock_instances) - mocker.patch.object(mock_repo, "_list_and_count_basic", return_value=(mock_instances, mock_count)) mocker.patch.object(mock_repo, "_list_and_count_window", return_value=(mock_instances, mock_count)) instances, instance_count = await maybe_async(mock_repo.list_and_count()) @@ -546,9 +545,8 @@ async def test_sqlalchemy_repo_list_and_count_basic( mock_instances = [MagicMock(), MagicMock()] mock_count = len(mock_instances) mocker.patch.object(mock_repo, "_list_and_count_basic", return_value=(mock_instances, mock_count)) - mocker.patch.object(mock_repo, "_list_and_count_window", return_value=(mock_instances, mock_count)) - instances, instance_count = await maybe_async(mock_repo.list_and_count(count_with_window_function=True)) + instances, instance_count = await maybe_async(mock_repo.list_and_count(count_with_window_function=False)) assert instances == mock_instances assert instance_count == mock_count