-
-
Notifications
You must be signed in to change notification settings - Fork 70
Expand file tree
/
Copy path_serialization.py
More file actions
194 lines (152 loc) · 6.51 KB
/
_serialization.py
File metadata and controls
194 lines (152 loc) · 6.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# ruff: noqa: PLR0911
import datetime
import decimal
import enum
import uuid
from typing import Any, ClassVar, Protocol, Union, cast
from typing_extensions import runtime_checkable
from advanced_alchemy.exceptions import MissingDependencyError
try:
    from pydantic import BaseModel  # type: ignore
except ImportError:

    @runtime_checkable
    class BaseModel(Protocol):  # type: ignore[no-redef]
        """Structural stand-in for ``pydantic.BaseModel`` used when pydantic is absent."""

        model_fields: ClassVar[dict[str, Any]]

        def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
            """Mirror the ``pydantic.BaseModel.model_dump_json`` signature.

            Raises:
                MissingDependencyError: Always, since pydantic is not installed.
            """
            msg = "pydantic"
            raise MissingDependencyError(msg)

    PYDANTIC_INSTALLED = False
else:
    PYDANTIC_INSTALLED = True  # pyright: ignore[reportConstantRedefinition]
def _type_to_string(value: Any) -> str:  # pragma: no cover
    """Encoder fallback hook: turn a value the JSON backend cannot handle into text.

    Datetimes and dates are formatted by the ISO helpers in this module, enum
    members become ``str()`` of their ``.value``, and pydantic models (when
    pydantic is installed) are dumped with ``model_dump_json``. Anything else is
    passed through ``str()``.

    Args:
        value: The object the encoder could not serialize natively.

    Returns:
        The string representation of ``value``.

    Raises:
        TypeError: If ``str(value)`` itself fails.
    """
    # ``datetime`` subclasses ``date``, so it must be checked first.
    if isinstance(value, datetime.datetime):
        text = convert_datetime_to_gmt_iso(value)
    elif isinstance(value, datetime.date):
        text = convert_date_to_iso(value)
    elif isinstance(value, enum.Enum):
        text = str(value.value)
    elif PYDANTIC_INSTALLED and isinstance(value, BaseModel):
        text = value.model_dump_json()
    else:
        try:
            text = str(value)
        except Exception as exc:
            raise TypeError from exc
    return text
# Prefer msgspec, then orjson, then fall back to the stdlib ``json`` module.
# Each backend exposes ``encode_json(data) -> str`` plus ``decode_json`` and routes
# values it cannot serialize natively through ``_type_to_string``.
try:
    from msgspec.json import Decoder, Encoder

    encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
    decode_json = decoder.decode

    def encode_json(data: Any) -> str:  # pragma: no cover
        """Serialize ``data`` to a JSON string using msgspec."""
        return encoder.encode(data).decode("utf-8")

except ImportError:
    try:
        from orjson import (  # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
            OPT_NAIVE_UTC,  # pyright: ignore[reportUnknownVariableType]
            OPT_SERIALIZE_NUMPY,  # pyright: ignore[reportUnknownVariableType]
            OPT_SERIALIZE_UUID,  # pyright: ignore[reportUnknownVariableType]
        )
        from orjson import (  # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
            dumps as _encode_json,  # pyright: ignore[reportUnknownVariableType]
        )
        from orjson import (  # type: ignore[no-redef,assignment,import-not-found] # pyright: ignore[reportMissingImports]
            loads as decode_json,  # pyright: ignore[reportUnknownVariableType,reportUnusedImport]
        )

        def encode_json(data: Any) -> str:  # pragma: no cover
            """Serialize ``data`` to a JSON string using orjson.

            Passes ``OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID`` and
            ``_type_to_string`` as the fallback for unsupported types.
            """
            return _encode_json(  # type: ignore[no-any-return]
                data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
            ).decode("utf-8")

    except ImportError:
        from json import dumps as _json_dumps
        from json import loads as decode_json  # type: ignore[assignment] # noqa: F401

        def encode_json(data: Any) -> str:  # pragma: no cover
            """Serialize ``data`` to a JSON string using the stdlib ``json`` module.

            ``default=_type_to_string`` keeps this fallback consistent with the msgspec
            and orjson backends: datetimes, dates, enums, UUIDs, Decimals and other
            non-native values are stringified instead of raising ``TypeError``.
            """
            return _json_dumps(data, default=_type_to_string)
def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str:  # pragma: no cover
    """Handle datetime serialization for nested timestamps.

    Naive datetimes (``tzinfo is None``) are treated as UTC. Aware datetimes keep
    their own offset; they are not converted to UTC. A trailing ``+00:00`` offset
    is written as ``Z``.

    Args:
        dt: The datetime to format.

    Returns:
        str: The ISO 8601 formatted datetime string.
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=datetime.timezone.utc)
    iso = dt.isoformat()
    utc_suffix = "+00:00"
    # Only rewrite the offset at the very end: sub-minute offsets such as
    # "+00:00:30" also contain "+00:00" and must not be turned into "Z:30".
    if iso.endswith(utc_suffix):
        return iso[: -len(utc_suffix)] + "Z"
    return iso
def convert_date_to_iso(dt: datetime.date) -> str:  # pragma: no cover
    """Handle date serialization for nested values.

    Args:
        dt: The date to format.

    Returns:
        str: ``dt.isoformat()``, e.g. ``"2020-01-31"`` for a plain ``date``.
    """
    iso_text = dt.isoformat()
    return iso_text
def encode_complex_type(obj: Any) -> Any:
    """Wrap a non-JSON-native value in a ``{"__type__": ..., "value": ...}`` marker.

    The marker is the format that ``decode_complex_type`` turns back into objects:

    - ``datetime.datetime`` -> ``"datetime"`` with an ISO 8601 string
    - ``datetime.date`` -> ``"date"`` with an ISO 8601 string
    - ``datetime.time`` -> ``"time"`` with an ISO 8601 string
    - ``datetime.timedelta`` -> ``"timedelta"`` with total seconds as a float
    - ``decimal.Decimal`` -> ``"decimal"`` with ``str()`` of the value
    - ``bytes`` -> ``"bytes"`` with a hex string
    - ``uuid.UUID`` -> ``"uuid"`` with ``str()`` of the value
    - ``set`` / ``frozenset`` -> ``"set"`` with a list of the members
      (members are not encoded here)

    Args:
        obj: The object to encode.

    Returns:
        The marker dict, or ``None`` when ``obj`` is not one of the types above.
    """
    tag: Union[str, None] = None
    payload: Any = None
    # ``datetime`` subclasses ``date``, so it has to be tested first.
    if isinstance(obj, datetime.datetime):
        tag, payload = "datetime", obj.isoformat()
    elif isinstance(obj, datetime.date):
        tag, payload = "date", obj.isoformat()
    elif isinstance(obj, datetime.time):
        tag, payload = "time", obj.isoformat()
    elif isinstance(obj, datetime.timedelta):
        tag, payload = "timedelta", obj.total_seconds()
    elif isinstance(obj, decimal.Decimal):
        tag, payload = "decimal", str(obj)
    elif isinstance(obj, bytes):
        tag, payload = "bytes", obj.hex()
    elif isinstance(obj, uuid.UUID):
        tag, payload = "uuid", str(obj)
    elif isinstance(obj, (set, frozenset)):
        tag, payload = "set", list(cast("Union[set[Any], frozenset[Any]]", obj))  # type: ignore[redundant-cast]
    if tag is None:
        return None
    return {"__type__": tag, "value": payload}
def decode_complex_type(value: Any) -> Any:
    """Recursively rebuild objects from ``{"__type__": ..., "value": ...}`` markers.

    Lists are decoded element by element. Dicts are decoded value by value, with
    keys coerced to ``str``. Once its contents are decoded, a dict holding both
    ``"__type__"`` and ``"value"`` is passed to ``_decode_typed_marker``. Every
    other value is returned unchanged.
    """
    if isinstance(value, dict):
        mapping = cast("dict[Any, Any]", value)  # type: ignore[redundant-cast]
        # Decode children first so a marker's payload is already decoded.
        result: dict[str, Any] = {}
        for raw_key, item in mapping.items():
            text_key = str(raw_key)
            result[text_key] = decode_complex_type(item)
        is_marker = "__type__" in result and "value" in result
        return _decode_typed_marker(result) if is_marker else result
    if isinstance(value, list):
        items = cast("list[Any]", value)  # type: ignore[redundant-cast]
        return [decode_complex_type(item) for item in items]
    return value
def _decode_typed_marker(obj: dict[str, Any]) -> Any:
    """Rebuild the Python object described by a type-marker dict.

    Args:
        obj: A dict with a ``"__type__"`` tag and a ``"value"`` payload. Both keys
            are read unconditionally, so a missing key raises ``KeyError``.

    Returns:
        The reconstructed object for a recognised tag; otherwise ``obj`` itself.
    """
    tag = obj["__type__"]
    payload = obj["value"]
    # Ordered (tag, constructor) pairs, compared with ``==`` one at a time rather
    # than looked up in a dict, so an unhashable tag falls through instead of raising.
    constructors: tuple[tuple[str, Any], ...] = (
        ("datetime", datetime.datetime.fromisoformat),
        ("date", datetime.date.fromisoformat),
        ("time", datetime.time.fromisoformat),
        ("timedelta", lambda seconds: datetime.timedelta(seconds=seconds)),
        ("decimal", decimal.Decimal),
        ("bytes", bytes.fromhex),
        ("uuid", uuid.UUID),
        ("set", set),
    )
    for name, build in constructors:
        if tag == name:
            return build(payload)
    return obj