Skip to content

Commit b16605d

Browse files
authored
feat: refactor serializers & code cleanup (#661)
Cleans up the serializer logic in the repository and (new) cache module: - Introduced `encode_complex_type` and `decode_complex_type` functions to handle serialization and deserialization of complex types (datetime, date, time, timedelta, Decimal, bytes, UUID, set). - Removed legacy JSON encoder/decoder functions from `serializers.py` and replaced their usage with the new functions. - Added error handling for unsupported types during encoding. - Enhanced import handling for optional dependencies in Alembic scripts.
1 parent 658d417 commit b16605d

12 files changed

Lines changed: 1584 additions & 1312 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/provinzkraut/unasyncd
20-
rev: "v0.9.0"
20+
rev: "v0.10.0"
2121
hooks:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.14.13"
25+
rev: "v0.14.14"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/_serialization.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
# ruff: noqa: PLR0911
12
import datetime
3+
import decimal
24
import enum
3-
from typing import Any
5+
import uuid
6+
from typing import Any, ClassVar, Protocol, Union, cast
47

58
from typing_extensions import runtime_checkable
69

@@ -11,7 +14,6 @@
1114

1215
PYDANTIC_INSTALLED = True
1316
except ImportError:
14-
from typing import ClassVar, Protocol
1517

1618
@runtime_checkable
1719
class BaseModel(Protocol): # type: ignore[no-redef]
@@ -90,3 +92,95 @@ def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
9092
str: The ISO 8601 formatted date string.
9193
"""
9294
return dt.isoformat()
95+
96+
97+
def encode_complex_type(obj: Any) -> Any:
98+
"""Convert an object to a JSON-serializable format if possible.
99+
100+
Handles types that are not natively JSON serializable:
101+
- datetime, date, time: ISO format strings
102+
- timedelta: total seconds as float
103+
- Decimal: string representation
104+
- bytes: hex string
105+
- UUID: string representation
106+
- set, frozenset: list
107+
108+
Args:
109+
obj: The object to encode.
110+
111+
Returns:
112+
A JSON-serializable representation of the object, or None if the type is not supported.
113+
"""
114+
if isinstance(obj, datetime.datetime):
115+
return {"__type__": "datetime", "value": obj.isoformat()}
116+
if isinstance(obj, datetime.date):
117+
return {"__type__": "date", "value": obj.isoformat()}
118+
if isinstance(obj, datetime.time):
119+
return {"__type__": "time", "value": obj.isoformat()}
120+
if isinstance(obj, datetime.timedelta):
121+
return {"__type__": "timedelta", "value": obj.total_seconds()}
122+
if isinstance(obj, decimal.Decimal):
123+
return {"__type__": "decimal", "value": str(obj)}
124+
if isinstance(obj, bytes):
125+
return {"__type__": "bytes", "value": obj.hex()}
126+
if isinstance(obj, uuid.UUID):
127+
return {"__type__": "uuid", "value": str(obj)}
128+
if isinstance(obj, (set, frozenset)):
129+
items: list[Any] = list(cast("Union[set[Any], frozenset[Any]]", obj)) # type: ignore[redundant-cast]
130+
return {"__type__": "set", "value": items}
131+
return None
132+
133+
134+
def decode_complex_type(value: Any) -> Any:
135+
"""Recursively decode special type markers.
136+
137+
Decodes the special ``{"__type__": ..., "value": ...}`` structures.
138+
"""
139+
if isinstance(value, list):
140+
value_list = cast("list[Any]", value) # type: ignore[redundant-cast]
141+
return [decode_complex_type(v) for v in value_list]
142+
143+
if not isinstance(value, dict):
144+
return value
145+
146+
# Decode any nested values first
147+
value_dict = cast("dict[Any, Any]", value) # type: ignore[redundant-cast]
148+
decoded: dict[str, Any] = {str(k): decode_complex_type(v) for k, v in value_dict.items()}
149+
150+
# Then decode "typed" marker dicts
151+
if "__type__" in decoded and "value" in decoded:
152+
return _decode_typed_marker(decoded)
153+
154+
return decoded
155+
156+
157+
def _decode_typed_marker(obj: dict[str, Any]) -> Any:
158+
"""Custom JSON decoder for special types.
159+
160+
Args:
161+
obj: The dictionary to decode.
162+
163+
Returns:
164+
The decoded object, or the original dict if not a special type.
165+
"""
166+
type_name = obj["__type__"]
167+
value = obj["value"]
168+
169+
if type_name == "datetime":
170+
return datetime.datetime.fromisoformat(value)
171+
if type_name == "date":
172+
return datetime.date.fromisoformat(value)
173+
if type_name == "time":
174+
return datetime.time.fromisoformat(value)
175+
if type_name == "timedelta":
176+
return datetime.timedelta(seconds=value)
177+
if type_name == "decimal":
178+
return decimal.Decimal(value)
179+
if type_name == "bytes":
180+
return bytes.fromhex(value)
181+
if type_name == "uuid":
182+
return uuid.UUID(value)
183+
if type_name == "set":
184+
return set(value)
185+
186+
return obj

advanced_alchemy/alembic/templates/asyncio/script.py.mako

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@ import sqlalchemy as sa
1313
from alembic import op
1414
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash, FernetBackend
1515
from advanced_alchemy.types.encrypted_string import PGCryptoBackend
16-
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
17-
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
18-
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
1916
from sqlalchemy import Text # noqa: F401
2017
${imports if imports else ""}
18+
try:
19+
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
20+
except ImportError:
21+
Argon2Hasher = Any # type: ignore
22+
try:
23+
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
24+
except ImportError:
25+
PasslibHasher = Any # type: ignore
26+
try:
27+
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
28+
except ImportError:
29+
PwdlibHasher = Any # type: ignore
30+
2131
if TYPE_CHECKING:
2232
from collections.abc import Sequence
2333

advanced_alchemy/alembic/templates/sync/script.py.mako

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@ import sqlalchemy as sa
1313
from alembic import op
1414
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash, FernetBackend
1515
from advanced_alchemy.types.encrypted_string import PGCryptoBackend
16-
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
17-
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
18-
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
1916
from sqlalchemy import Text # noqa: F401
2017
${imports if imports else ""}
18+
try:
19+
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
20+
except ImportError:
21+
Argon2Hasher = Any # type: ignore
22+
try:
23+
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
24+
except ImportError:
25+
PasslibHasher = Any # type: ignore
26+
try:
27+
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
28+
except ImportError:
29+
PwdlibHasher = Any # type: ignore
30+
2131
if TYPE_CHECKING:
2232
from collections.abc import Sequence
2333

advanced_alchemy/cache/serializers.py

Lines changed: 13 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Serialization utilities for caching SQLAlchemy models."""
22

3-
from datetime import date, datetime, time, timedelta
4-
from decimal import Decimal
5-
from typing import Any, TypeVar, Union, cast
6-
from uuid import UUID
3+
from typing import Any, TypeVar
74

85
from sqlalchemy import inspect as sa_inspect
96

10-
from advanced_alchemy._serialization import decode_json, encode_json
7+
from advanced_alchemy._serialization import (
8+
decode_complex_type,
9+
decode_json,
10+
encode_complex_type,
11+
encode_json,
12+
)
1113

1214
__all__ = (
1315
"default_deserializer",
@@ -23,107 +25,6 @@
2325
"""Metadata key for the table name in serialized data."""
2426

2527

26-
def _json_encoder(obj: Any) -> Any: # noqa: PLR0911
27-
"""Custom JSON encoder for SQLAlchemy model attributes.
28-
29-
Handles types that are not natively JSON serializable:
30-
- datetime, date, time: ISO format strings
31-
- timedelta: total seconds as float
32-
- Decimal: string representation
33-
- bytes: hex string
34-
- UUID: string representation
35-
- set, frozenset: list
36-
37-
Args:
38-
obj: The object to encode.
39-
40-
Returns:
41-
A JSON-serializable representation of the object.
42-
43-
Raises:
44-
TypeError: If the object type is not supported.
45-
"""
46-
if isinstance(obj, datetime):
47-
return {"__type__": "datetime", "value": obj.isoformat()}
48-
if isinstance(obj, date):
49-
return {"__type__": "date", "value": obj.isoformat()}
50-
if isinstance(obj, time):
51-
return {"__type__": "time", "value": obj.isoformat()}
52-
if isinstance(obj, timedelta):
53-
return {"__type__": "timedelta", "value": obj.total_seconds()}
54-
if isinstance(obj, Decimal):
55-
return {"__type__": "decimal", "value": str(obj)}
56-
if isinstance(obj, bytes):
57-
return {"__type__": "bytes", "value": obj.hex()}
58-
if isinstance(obj, UUID):
59-
return {"__type__": "uuid", "value": str(obj)}
60-
if isinstance(obj, (set, frozenset)):
61-
items: list[Any] = list(cast("Union[set[Any], frozenset[Any]]", obj)) # type: ignore[redundant-cast]
62-
return {"__type__": "set", "value": items}
63-
msg = f"Object of type {type(obj).__name__} is not JSON serializable"
64-
raise TypeError(msg)
65-
66-
67-
def _json_decoder(obj: dict[str, Any]) -> Any: # noqa: PLR0911
68-
"""Custom JSON decoder for special types.
69-
70-
Args:
71-
obj: The dictionary to decode.
72-
73-
Returns:
74-
The decoded object, or the original dict if not a special type.
75-
"""
76-
if "__type__" not in obj:
77-
return obj
78-
79-
type_name = obj["__type__"]
80-
value = obj["value"]
81-
82-
if type_name == "datetime":
83-
return datetime.fromisoformat(value)
84-
if type_name == "date":
85-
return date.fromisoformat(value)
86-
if type_name == "time":
87-
return time.fromisoformat(value)
88-
if type_name == "timedelta":
89-
return timedelta(seconds=value)
90-
if type_name == "decimal":
91-
return Decimal(value)
92-
if type_name == "bytes":
93-
return bytes.fromhex(value)
94-
if type_name == "uuid":
95-
return UUID(value)
96-
if type_name == "set":
97-
return set(value)
98-
99-
return obj
100-
101-
102-
def _decode_special_types(value: Any) -> Any:
103-
"""Recursively decode special type markers.
104-
105-
When using ``encode_json`` (msgspec/orjson/json fallback), we can't rely on
106-
stdlib json's ``object_hook`` callback. This helper decodes the special
107-
``{"__type__": ..., "value": ...}`` structures produced by ``_json_encoder``.
108-
"""
109-
if isinstance(value, list):
110-
value_list = cast("list[Any]", value) # type: ignore[redundant-cast]
111-
return [_decode_special_types(v) for v in value_list]
112-
113-
if not isinstance(value, dict):
114-
return value
115-
116-
# Decode any nested values first
117-
value_dict = cast("dict[Any, Any]", value) # type: ignore[redundant-cast]
118-
decoded: dict[str, Any] = {str(k): _decode_special_types(v) for k, v in value_dict.items()}
119-
120-
# Then decode "typed" marker dicts
121-
if "__type__" in decoded and "value" in decoded:
122-
return _json_decoder(decoded)
123-
124-
return decoded
125-
126-
12728
def default_serializer(model: Any) -> bytes:
12829
"""Serialize a SQLAlchemy model instance to JSON bytes.
12930
@@ -160,11 +61,11 @@ def default_serializer(model: Any) -> bytes:
16061
if getattr(column, "_insert_sentinel", False):
16162
continue
16263
value = getattr(model, column.key)
163-
try:
164-
# Encode special types into JSON-friendly marker structures.
165-
data[column.key] = _json_encoder(value)
166-
except TypeError:
167-
# Leave unknown types alone; encode_json has its own hooks/fallbacks.
64+
65+
# Encode special types into JSON-friendly marker structures.
66+
if (encoded := encode_complex_type(value)) is not None:
67+
data[column.key] = encoded
68+
else:
16869
data[column.key] = value
16970

17071
return encode_json(data).encode("utf-8")
@@ -201,7 +102,7 @@ def default_deserializer(data: bytes, model_class: type[T]) -> T:
201102
# user is a detached User instance
202103
"""
203104
parsed_raw = decode_json(data)
204-
parsed = _decode_special_types(parsed_raw)
105+
parsed = decode_complex_type(parsed_raw)
205106

206107
# Validate model class matches
207108
serialized_model = parsed.pop(_MODEL_KEY, None)

0 commit comments

Comments
 (0)