Coverage for muutils\json_serialize\serializable_dataclass.py: 54%
247 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
1"""save and load objects to and from json or compatible formats in a recoverable way
3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
6Instead, you define your class:
8```python
9@serializable_dataclass
10class MyClass(SerializableDataclass):
11 a: int
12 b: str
13```
15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
17 >>> my_obj = MyClass(a=1, b="q")
18 >>> s = json.dumps(my_obj.serialize())
19 >>> s
20 '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
21 >>> read_obj = MyClass.load(json.loads(s))
22 >>> read_obj == my_obj
23 True
25This isn't too impressive on its own, but it gets more useful when you have nested classes,
26or fields that are not json-serializable by default:
28```python
29@serializable_dataclass
30class NestedClass(SerializableDataclass):
31 x: str
32 y: MyClass
33 act_fun: torch.nn.Module = serializable_field(
34 default=torch.nn.ReLU(),
35 serialization_fn=lambda x: str(x),
36 deserialize_fn=lambda x: getattr(torch.nn, x)(),
37 )
38```
40which gives us:
42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
43 >>> s = json.dumps(nc.serialize())
44 >>> s
45 '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
46 >>> read_nc = NestedClass.load(json.loads(s))
47 >>> read_nc == nc
48 True
50"""
52from __future__ import annotations
54import abc
55import dataclasses
56import functools
57import json
58import sys
59import typing
60import warnings
61from typing import Any, Optional, Type, TypeVar
63from muutils.errormode import ErrorMode
64from muutils.validate_type import validate_type
65from muutils.json_serialize.serializable_field import (
66 SerializableField,
67 serializable_field,
68)
69from muutils.json_serialize.util import array_safe_eq, dc_eq
71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
74def _dataclass_transform_mock(
75 *,
76 eq_default: bool = True,
77 order_default: bool = False,
78 kw_only_default: bool = False,
79 frozen_default: bool = False,
80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
81 **kwargs: Any,
82) -> typing.Callable:
83 "mock `typing.dataclass_transform` for python <3.11"
85 def decorator(cls_or_fn):
86 cls_or_fn.__dataclass_transform__ = {
87 "eq_default": eq_default,
88 "order_default": order_default,
89 "kw_only_default": kw_only_default,
90 "frozen_default": frozen_default,
91 "field_specifiers": field_specifiers,
92 "kwargs": kwargs,
93 }
94 return cls_or_fn
96 return decorator
99dataclass_transform: typing.Callable
100if sys.version_info < (3, 11):
101 dataclass_transform = _dataclass_transform_mock
102else:
103 dataclass_transform = typing.dataclass_transform
106T = TypeVar("T")
class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""
class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing

    without ZANJ, `register_loader_serializable_dataclass` will not work
    """
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : Type[T]`
       class to register; its `load` classmethod is what the handler calls

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if `zanj` could not
       be imported (in which case nothing is registered)

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # The import is attempted on every call rather than being gated on the
    # module flag: the names are local to this function, so skipping the import
    # would leave `LoaderHandler` / `register_loader_handler` unbound. Repeated
    # imports are served from `sys.modules`, so this is cheap.
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        # NOTE: if ZANJ is not installed, then failing to register the loader
        # handler doesnt matter, so this is deliberately silent (no
        # `ZanjMissingWarning` is emitted here)
        return None

    # record that ZANJ was imported successfully
    _zanj_loading_needs_import = False

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        # match any dict whose `__format__` starts with this class's format tag
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
# default for the `on_typecheck_mismatch` argument of `serializable_dataclass`:
# what to do when a loaded value does not match its type hint
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default for the `on_typecheck_error` arguments in this module:
# what to do when the type check itself throws an exception
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning for a field whose type check is skipped because it is not `init` or not `serialize`"""
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value stored in one field of a dataclass matches its type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
       field to validate, looked up in `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the field type is correct, or if the field is skipped
       (`assert_type` is off, or the field is not `init`/`serialize`).
       `False` if the field type is incorrect, or if an exception is thrown
       and `on_typecheck_error` does not re-raise it
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name into the field object stored on the dataclass
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # nothing to do when type assertion is switched off for this field
    if not _field.assert_type:
        return True

    # fields that are not both `init` and `serialize` are skipped with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint for this field; only a missing key is handled here,
    # any other exception from `get_cls_type_hints` propagates to the caller
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as missing_hint:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                " - use python 3.9.x or higher\n"
                " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=missing_hint,
        )
        return False

    # the value currently stored on the instance
    value: Any = getattr(self, _field.name)

    # run the validator; anything it throws is routed through `on_typecheck_error`
    try:
        custom_check = _field.custom_typecheck_fn
        if custom_check is None:
            # default type validator
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom validator is called with the type hint, not
        # with `value` -- confirm this matches the `custom_typecheck_fn` contract
        return custom_check(field_type_hint)
    except Exception as check_failure:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=check_failure,
        )
        return False
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `self.validate_field_type` for each field

    Exceptions raised while checking individual fields are collected, the
    affected fields are marked `False`, and the bundle is reported once through
    `on_typecheck_error` after all fields have been visited.

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    results: dict[str, bool] = {}
    # if except, bundle the exceptions (keyed by field name)
    exceptions: dict[str, Exception] = {}

    for field in cls_fields:
        try:
            field_ok: bool = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            exceptions[field.name] = e
            field_ok = False
        results[field.name] = field_ok

    # report all collected exceptions in a single message
    if exceptions:
        header: str = f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
        details: str = "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()])
        on_typecheck_error.process(
            header + "\n\t" + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(exceptions.values())),
        )

    return results
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    Thin wrapper: `True` only if every entry returned by
    `SerializableDataclass__validate_fields_types__dict` is truthy.
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator replaces this with the real implementation
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator replaces this with the real implementation
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields; delegates to `SerializableDataclass__validate_fields_types`"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check that one field matches its type hint; delegates to `SerializableDataclass__validate_field_type`"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        # equality is delegated to the `dc_eq` helper
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (the reported `self`/`other` entries are still the raw attribute values)
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal; otherwise maps each differing
           field name to either a nested diff or `{"self": ..., "other": ...}`

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        changes: dict = {}

        # equal instances have an empty diff
        if self == other:
            return changes

        # serialized views are only computed when they will be compared
        serialized_self: Optional[dict] = self.serialize() if of_serialized else None
        serialized_other: Optional[dict] = other.serialize() if of_serialized else None

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are excluded from the diff as well
            if not field.compare:
                continue

            name: str = field.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # two serializable dataclasses: recurse and keep only non-empty diffs
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                nested: dict = mine.diff(theirs, of_serialized=of_serialized)
                if nested:
                    changes[name] = nested
                continue

            # plain dataclasses on both sides are rejected
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick which representation to compare
            if of_serialized:
                left, right = serialized_self[name], serialized_other[name]  # type: ignore[index]
            else:
                left, right = mine, theirs

            if not array_safe_eq(left, right):
                changes[name] = {"self": mine, "other": theirs}

        return changes

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        Fields holding a `SerializableDataclass` are updated recursively; all
        other fields present in `nested_dict` are overwritten with `setattr`.

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            name: str = field.name
            current = getattr(self, name)

            if name not in nested_dict:
                continue

            replacement = nested_dict[name]
            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(replacement)
            else:
                setattr(self, name, replacement)

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_roundtrip = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(json_roundtrip)

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        # `memo` is not consulted: the copy goes through a json round trip
        json_roundtrip = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(json_roundtrip)
# memoized so repeated lookups for the same class do not recompute the hints
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    """cached `typing.get_type_hints` for a class

    the same dict object is returned on repeated calls for the same class
    """
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    Tries the cached lookup first and falls back to an uncached
    `typing.get_type_hints` call when the cached result is empty.

    # Raises:
     - `TypeError` : if the hints cannot be resolved (`TypeError` or `NameError`
       from the lookup) or if they are empty after both attempts
    """
    cls_type_hints: dict[str, Any]
    try:
        # an empty cached result triggers one uncached retry
        cls_type_hints = get_cls_type_hints_cached(cls) or typing.get_type_hints(cls)  # type: ignore

        # still empty: raised here and converted to `TypeError` just below
        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f" {e = }"
        ) from e

    return cls_type_hints
class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported on this python version

    raised by `serializable_dataclass` when `kw_only=True` is requested and
    `sys.version_info < (3, 10)`
    """
class FieldError(ValueError):
    """base class for field errors"""


class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""


class FieldSerializationError(FieldError):
    """error while serializing a field"""


class FieldLoadingError(FieldError):
    """error while loading a field"""


class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint

    catchable as either a `FieldError` (hence `ValueError`) or a `TypeError`
    """
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
     - `**kwargs`
       forwarded to `dataclasses.dataclass` (e.g. `kw_only`)

    # Returns:
     - `_type_`
       the decorated class (or, when used with arguments, the decorator to apply)

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    _properties_to_serialize: list = (
        list() if properties_to_serialize is None else properties_to_serialize
    )

    def wrap(cls: Type[T]) -> Type[T]:
        # Make sure every annotated attribute is backed by a `SerializableField`
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    # NOTE(review): a plain class-level default that is not a
                    # `dataclasses.Field` is replaced here -- confirm intended
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check: `dataclasses.dataclass` only accepts `kw_only` from python 3.10 on.
        # `== True` is deliberate, since `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10) and "kw_only" in kwargs:
            if kwargs["kw_only"] == True:  # noqa: E712
                raise KWOnlyError("kw_only is not supported in python <3.10")
            # drop the unsupported key so it is not forwarded to `dataclasses.dataclass`
            del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class (a copy, so classes never share the caller's list)
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # fields flagged as not serializable are left out of the output
                if not field.serialize:
                    continue

                # bound before the `try` so the error message below can always
                # be built, even when `getattr` itself is what failed
                value: Any = None
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    # NOTE(review): this is an `if`, not an `elif`, so for a
                    # `SerializableDataclass` value the field's `serialization_fn`
                    # (if any) is applied to the already-serialized dict -- confirm intended
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    result[prop] = getattr(self, prop)
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # only fields present in the data and accepted by `__init__` are passed on
                if not ((field.name in data) and field.init):
                    continue

                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that (receives the field's value)
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that (receives the whole `data` mapping)
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                # otherwise: assume no loading needs to happen, keep `value` as-is

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, report them
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        # NOTE(review): every class inherits `__eq__` from `object`, so this
        # branch looks unreachable; kept as-is to preserve behavior
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)