Skip to content

Commit 3b9a346

Browse files
authored
feat(filters): Implements a MultiFilter type for complex searches (#311)
Implement a "Multi-Filter" Filter type. It allows: - Create a collection of filters from an input - Allows filters to be groups with and/or logic
1 parent ef95fe8 commit 3b9a346

2 files changed

Lines changed: 742 additions & 6 deletions

File tree

advanced_alchemy/filters.py

Lines changed: 290 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,26 @@
3333
"""
3434

3535
import datetime
36+
import logging
3637
from abc import ABC, abstractmethod
3738
from collections.abc import Collection
3839
from dataclasses import dataclass
3940
from operator import attrgetter
40-
from typing import Any, Callable, Generic, Literal, Optional, Union, cast
41+
from typing import (
42+
Any,
43+
Callable,
44+
ClassVar,
45+
Generic,
46+
Literal,
47+
Optional,
48+
Union,
49+
cast,
50+
)
4151

4252
from sqlalchemy import (
4353
BinaryExpression,
4454
ColumnElement,
55+
Date,
4556
Delete,
4657
Select,
4758
Update,
@@ -56,18 +67,24 @@
5667
true,
5768
)
5869
from sqlalchemy.orm import InstrumentedAttribute
70+
from sqlalchemy.sql import operators as op
5971
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
60-
from typing_extensions import TypeAlias, TypeVar
72+
from typing_extensions import TypeAlias, TypedDict, TypeVar
6173

6274
from advanced_alchemy.base import ModelProtocol
6375

6476
__all__ = (
6577
"BeforeAfter",
6678
"CollectionFilter",
79+
"ComparisonFilter",
6780
"ExistsFilter",
81+
"FilterGroup",
82+
"FilterMap",
6883
"FilterTypes",
6984
"InAnyFilter",
7085
"LimitOffset",
86+
"LogicalOperatorMap",
87+
"MultiFilter",
7188
"NotExistsFilter",
7289
"NotInCollectionFilter",
7390
"NotInSearchFilter",
@@ -89,9 +106,32 @@
89106
ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Select[Any], Update, Delete
90107
],
91108
)
92-
FilterTypes: TypeAlias = "Union[BeforeAfter, OnBeforeAfter, CollectionFilter[Any], LimitOffset, OrderBy, SearchFilter, NotInCollectionFilter[Any], NotInSearchFilter, ExistsFilter, NotExistsFilter]"
109+
FilterTypes: TypeAlias = "Union[BeforeAfter, OnBeforeAfter, CollectionFilter[Any], LimitOffset, OrderBy, SearchFilter, NotInCollectionFilter[Any], NotInSearchFilter, ExistsFilter, NotExistsFilter, ComparisonFilter, MultiFilter, FilterGroup]"
93110
"""Aggregate type alias of the types supported for collection filtering."""
94111

112+
logger = logging.getLogger("advanced_alchemy")
113+
114+
115+
# Define TypedDicts for filter and logical maps
116+
class FilterMap(TypedDict):
117+
before_after: "type[BeforeAfter]"
118+
on_before_after: "type[OnBeforeAfter]"
119+
collection: "type[CollectionFilter[Any]]"
120+
not_in_collection: "type[NotInCollectionFilter[Any]]"
121+
limit_offset: "type[LimitOffset]"
122+
order_by: "type[OrderBy]"
123+
search: "type[SearchFilter]"
124+
not_in_search: "type[NotInSearchFilter]"
125+
comparison: "type[ComparisonFilter]"
126+
exists: "type[ExistsFilter]"
127+
not_exists: "type[NotExistsFilter]"
128+
filter_group: "type[FilterGroup]"
129+
130+
131+
class LogicalOperatorMap(TypedDict):
132+
and_: Callable[..., ColumnElement[bool]]
133+
or_: Callable[..., ColumnElement[bool]]
134+
95135

96136
class StatementFilter(ABC):
97137
"""Abstract base class for SQLAlchemy statement filters.
@@ -519,9 +559,14 @@ def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]
519559
"""
520560
search_clause: list[BinaryExpression[bool]] = []
521561
for field_name in self.normalized_field_names:
522-
field = self._get_instrumented_attr(model, field_name)
523-
search_text = f"%{self.value}%"
524-
search_clause.append(self._func(field)(search_text))
562+
try:
563+
field = self._get_instrumented_attr(model, field_name)
564+
search_text = f"%{self.value}%"
565+
search_clause.append(self._func(field)(search_text))
566+
except AttributeError:
567+
msg = f"Skipping search for field {field_name}. It is not found in model {model.__name__}"
568+
logger.debug(msg)
569+
continue
525570
return search_clause
526571

527572
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
@@ -541,6 +586,57 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
541586
return cast("StatementTypeT", statement.where(where_clause))
542587

543588

589+
@dataclass
590+
class ComparisonFilter(StatementFilter):
591+
"""Simple comparison filter for equality and inequality operations.
592+
593+
This filter applies basic comparison operators (=, !=, >, >=, <, <=) to a field.
594+
It provides a generic way to perform common comparison operations.
595+
596+
Attributes:
597+
----------~
598+
field_name : str
599+
Name of the model attribute to filter on
600+
operator : str
601+
Comparison operator to use ('eq', 'ne', 'gt', 'ge', 'lt', 'le')
602+
value : Any
603+
Value to compare against
604+
605+
Examples:
606+
--------~
607+
>>> filter = SimpleFilter(
608+
... field_name="age", operator="gt", value=18
609+
... )
610+
>>> statement = filter.append_to_statement(select(User), User)
611+
"""
612+
613+
field_name: str
614+
"""Name of the model attribute to filter on."""
615+
operator: str
616+
"""Comparison operator to use (one of 'eq', 'ne', 'gt', 'ge', 'lt', 'le')."""
617+
value: Any
618+
"""Value to compare against."""
619+
620+
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
621+
"""Apply a comparison operation to the statement.
622+
623+
Args:
624+
statement: The SQLAlchemy statement to modify
625+
model: The SQLAlchemy model class
626+
627+
Returns:
628+
StatementTypeT: Modified statement with the comparison condition
629+
"""
630+
field = self._get_instrumented_attr(model, self.field_name)
631+
operator_func = operators_map.get(self.operator)
632+
633+
if operator_func is None:
634+
return statement
635+
636+
condition = operator_func(field, self.value)
637+
return cast("StatementTypeT", statement.where(condition))
638+
639+
544640
@dataclass
545641
class NotInSearchFilter(SearchFilter):
546642
"""Filter for excluding records that match a substring.
@@ -849,3 +945,191 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
849945
# as get_exists_clause handles the empty case by returning true.
850946
exists_clause = self.get_exists_clause(model)
851947
return cast("StatementTypeT", statement.where(exists_clause))
948+
949+
950+
@dataclass
951+
class FilterGroup(StatementFilter):
952+
"""A group of filters combined with a logical operator.
953+
954+
This class combines multiple filters with a logical operator (AND/OR).
955+
It provides a way to create complex nested filter conditions.
956+
957+
Attributes:
958+
----------~
959+
logical_operator : Callable[..., ColumnElement[bool]]
960+
The SQLAlchemy operator to combine filters with (and_, or_)
961+
filters : list[StatementFilter]
962+
List of filters to apply
963+
"""
964+
965+
logical_operator: Callable[..., ColumnElement[bool]]
966+
"""Logical operator to combine the filters (e.g., and_, or_)."""
967+
filters: list[StatementFilter]
968+
"""List of filters to combine."""
969+
970+
def append_to_statement(
971+
self,
972+
statement: StatementTypeT,
973+
model: type[ModelT],
974+
) -> "StatementTypeT":
975+
"""Apply all filters combined with the logical operator.
976+
977+
Args:
978+
statement: The SQLAlchemy statement to modify
979+
model: The SQLAlchemy model class
980+
981+
Returns:
982+
StatementTypeT: Modified statement with combined filters
983+
"""
984+
if not self.filters:
985+
return statement
986+
987+
# Create a list of expressions from each filter
988+
expressions = []
989+
for filter_obj in self.filters:
990+
# Each filter needs to be applied to a clean version of the statement
991+
# to get just its expression
992+
filter_statement = filter_obj.append_to_statement(select(), model)
993+
# Extract the whereclause from the filter's statement
994+
if hasattr(filter_statement, "whereclause") and filter_statement.whereclause is not None:
995+
expressions.append(filter_statement.whereclause) # pyright: ignore
996+
997+
if expressions:
998+
# Combine all expressions with the logical operator
999+
combined = self.logical_operator(*expressions)
1000+
return cast("StatementTypeT", statement.where(combined))
1001+
return statement
1002+
1003+
1004+
# Regular typed dictionary for operators_map
1005+
operators_map: dict[str, Callable[[Any, Any], ColumnElement[bool]]] = {
1006+
"eq": op.eq,
1007+
"ne": op.ne,
1008+
"gt": op.gt,
1009+
"ge": op.ge,
1010+
"lt": op.lt,
1011+
"le": op.le,
1012+
"in": op.in_op,
1013+
"notin": op.notin_op,
1014+
"between": lambda c, v: c.between(v[0], v[1]),
1015+
"like": op.like_op,
1016+
"ilike": op.ilike_op,
1017+
"startswith": op.startswith_op,
1018+
"istartswith": lambda c, v: c.ilike(v + "%"),
1019+
"endswith": op.endswith_op,
1020+
"iendswith": lambda c, v: c.ilike(v + "%"),
1021+
"dateeq": lambda c, v: cast("Date", c) == v,
1022+
}
1023+
1024+
1025+
@dataclass
1026+
class MultiFilter(StatementFilter):
1027+
"""Apply multiple filters to a query based on a JSON/dict input.
1028+
1029+
This filter provides a way to construct complex filter trees from
1030+
a structured dictionary input, supporting nested logical groups and
1031+
various filter types.
1032+
1033+
Attributes:
1034+
----------~
1035+
filters : dict[str, Any]
1036+
Dictionary structure representing the filters, where keys can be
1037+
logical operators ("and_", "or_") and values are lists of filter
1038+
definitions.
1039+
"""
1040+
1041+
filters: dict[str, Any]
1042+
"""JSON/dict structure representing the filters."""
1043+
1044+
# TypedDict class variables
1045+
_filter_map: ClassVar[FilterMap] = {
1046+
"before_after": BeforeAfter,
1047+
"on_before_after": OnBeforeAfter,
1048+
"collection": CollectionFilter,
1049+
"not_in_collection": NotInCollectionFilter,
1050+
"limit_offset": LimitOffset,
1051+
"order_by": OrderBy,
1052+
"search": SearchFilter,
1053+
"not_in_search": NotInSearchFilter,
1054+
"filter_group": FilterGroup,
1055+
"comparison": ComparisonFilter,
1056+
"exists": ExistsFilter,
1057+
"not_exists": NotExistsFilter,
1058+
}
1059+
1060+
_logical_map: ClassVar[LogicalOperatorMap] = {
1061+
"and_": and_,
1062+
"or_": or_,
1063+
}
1064+
1065+
def append_to_statement(
1066+
self,
1067+
statement: StatementTypeT,
1068+
model: type[ModelT],
1069+
) -> StatementTypeT:
1070+
"""Apply the filters to the statement based on the filter definitions.
1071+
1072+
Args:
1073+
statement: The SQLAlchemy statement to modify
1074+
model: The SQLAlchemy model class
1075+
1076+
Returns:
1077+
StatementTypeT: Modified statement with all filters applied
1078+
"""
1079+
for filter_type, conditions in self.filters.items():
1080+
operator = self._logical_map.get(filter_type)
1081+
if operator and isinstance(conditions, list):
1082+
# Create filters from the conditions
1083+
valid_filters = []
1084+
for cond in conditions: # pyright: ignore
1085+
filter_instance = self._create_filter(cond) # pyright: ignore
1086+
if filter_instance is not None:
1087+
valid_filters.append(filter_instance) # pyright: ignore
1088+
1089+
# Only create a filter group if we have valid filters
1090+
if valid_filters:
1091+
filter_group = FilterGroup(
1092+
logical_operator=operator, # type: ignore
1093+
filters=valid_filters, # pyright: ignore
1094+
)
1095+
statement = filter_group.append_to_statement(statement, model)
1096+
return statement
1097+
1098+
def _create_filter(self, condition: dict[str, Any]) -> Optional[StatementFilter]:
1099+
"""Create a filter instance from a condition dictionary.
1100+
1101+
Args:
1102+
condition: Dictionary defining a filter
1103+
1104+
Returns:
1105+
Optional[StatementFilter]: Filter instance if successfully created, None otherwise
1106+
"""
1107+
# Check if condition is a nested logical group
1108+
logical_keys = set(self._logical_map.keys())
1109+
intersect = logical_keys.intersection(condition.keys())
1110+
if intersect:
1111+
# It's a nested filter group
1112+
for key in intersect:
1113+
operator = self._logical_map.get(key)
1114+
if operator and isinstance(condition.get(key), list):
1115+
nested_filters = []
1116+
for cond in condition[key]:
1117+
filter_instance = self._create_filter(cond)
1118+
if filter_instance is not None:
1119+
nested_filters.append(filter_instance) # pyright: ignore
1120+
1121+
if nested_filters:
1122+
return FilterGroup(logical_operator=operator, filters=nested_filters) # type: ignore
1123+
else:
1124+
# Regular filter
1125+
filter_type = condition.get("type")
1126+
if filter_type is not None and isinstance(filter_type, str):
1127+
filter_class = self._filter_map.get(filter_type)
1128+
if filter_class is not None:
1129+
try:
1130+
# Create a copy of the condition without the type key
1131+
filter_args = {k: v for k, v in condition.items() if k != "type"}
1132+
return filter_class(**filter_args) # type: ignore
1133+
except Exception: # noqa: BLE001
1134+
return None
1135+
return None

0 commit comments

Comments
 (0)