docs for muutils v0.7.0

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.
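Here is the failure mode using only the standard library. The `Event` class and its `datetime.date` field are illustrative choices and not part of muutils. Any field type that the default `json` encoder does not know about behaves the same way.

```python
import dataclasses
import datetime
import json


@dataclasses.dataclass
class Event:
    name: str
    when: datetime.date  # the stdlib json encoder has no rule for dates


event = Event(name="launch", when=datetime.date(2024, 1, 1))
d = dataclasses.asdict(event)  # succeeds: a plain dict, but d["when"] is still a date

error_message = None
try:
    json.dumps(d)
except TypeError as e:
    # json.dumps rejects objects it does not know how to encode
    error_message = str(e)

print(error_message)  # e.g. "Object of type date is not JSON serializable"
```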

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
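In concept, the round trip is a `"__format__"` tag followed by a constructor call. The sketch below imitates that with plain `dataclasses` and `json`. `to_tagged_dict` and `from_tagged_dict` are hypothetical helpers written for illustration and are not muutils functions. They also skip the type validation and nested handling that `@serializable_dataclass` adds.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass
class MyClass:
    a: int
    b: str


def to_tagged_dict(obj: Any) -> Dict[str, Any]:
    """hypothetical helper: field values plus a "__format__" tag naming the class"""
    result: Dict[str, Any] = {"__format__": f"{type(obj).__name__}(SerializableDataclass)"}
    for field in dataclasses.fields(obj):
        result[field.name] = getattr(obj, field.name)
    return result


def from_tagged_dict(cls: type, data: Dict[str, Any]) -> Any:
    """hypothetical helper: check the tag, then rebuild the instance from its init fields"""
    expected = f"{cls.__name__}(SerializableDataclass)"
    if data.get("__format__") != expected:
        raise ValueError(f"expected format {expected!r}, got {data.get('__format__')!r}")
    kwargs = {f.name: data[f.name] for f in dataclasses.fields(cls) if f.init and f.name in data}
    return cls(**kwargs)


my_obj = MyClass(a=1, b="q")
s = json.dumps(to_tagged_dict(my_obj))
print(s)  # {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
read_obj = from_tagged_dict(MyClass, json.loads(s))
print(read_obj == my_obj)  # True: the dataclass-generated __eq__ compares field values
```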

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,  # e.g. "Sigmoid"
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
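The per-field hooks and the recursion into nested classes can also be sketched with only the standard library. This is an illustrative stand-in and not muutils' implementation. The hypothetical metadata keys `"serialize"` and `"deserialize"` take the place of `serialization_fn` and `deserialize_fn`. `math.sqrt` replaces the torch activation so the example has no third-party dependency. Nested loading reads the target class from the field annotation, so the sketch assumes annotations are real classes (no `from __future__ import annotations`).

```python
import dataclasses
import json
import math
from typing import Any, Callable, Dict


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    # hypothetical metadata keys standing in for serialization_fn / deserialize_fn:
    # store the function by name, restore it by looking the name up on `math`
    fn: Callable[[float], float] = dataclasses.field(
        default=math.sqrt,
        metadata={
            "serialize": lambda f: f.__name__,
            "deserialize": lambda name: getattr(math, name),
        },
    )


def dump_nested(obj: Any) -> Any:
    """turn a (possibly nested) dataclass instance into json-compatible data"""
    if not dataclasses.is_dataclass(obj):
        return obj
    out: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if "serialize" in f.metadata:
            out[f.name] = f.metadata["serialize"](value)  # per-field hook wins
        else:
            out[f.name] = dump_nested(value)  # recurse into nested dataclasses
    return out


def load_nested(cls: type, data: Dict[str, Any]) -> Any:
    """rebuild `cls` from the output of `dump_nested`"""
    kwargs: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue  # fall back to the field's default
        raw = data[f.name]
        if "deserialize" in f.metadata:
            kwargs[f.name] = f.metadata["deserialize"](raw)
        elif isinstance(f.type, type) and dataclasses.is_dataclass(f.type):
            kwargs[f.name] = load_nested(f.type, raw)  # the annotated class says what to build
        else:
            kwargs[f.name] = raw
    return cls(**kwargs)


obj = Outer(x="q", y=Inner(a=1, b="q"), fn=math.sqrt)
s = json.dumps(dump_nested(obj))
print(s)  # {"x": "q", "y": {"a": 1, "b": "q"}, "fn": "sqrt"}
print(load_nested(Outer, json.loads(s)) == obj)  # True
```

Attaching the hook to the field, not to the stored value's type, matches the `serializable_field(...)` approach above. The field decides how its value is encoded.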

"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,  # e.g. "Sigmoid"
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True

"""

from __future__ import annotations

import abc
import dataclasses
import functools
import json
import sys
import typing
import warnings
from typing import Any, Optional, Type, TypeVar

from muutils.errormode import ErrorMode
from muutils.validate_type import validate_type
from muutils.json_serialize.serializable_field import (
    SerializableField,
    serializable_field,
)
from muutils.json_serialize.util import array_safe_eq, dc_eq

# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access


def _dataclass_transform_mock(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> typing.Callable:
    "mock `typing.dataclass_transform` for python <3.11"

    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn

    return decorator


dataclass_transform: typing.Callable
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


T = TypeVar("T")


class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `zanj_register_loader_serializable_dataclass` will not work"

    pass


_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesn't matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh


_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass


def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False


def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so use the first exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results


def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,  # e.g. "Sigmoid"
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps each differing field name to `{'self': ..., 'other': ...}`, or to a nested diff for `SerializableDataclass` fields; empty if the instances are equal


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if a field holds a `dataclasses.dataclass` that is not a `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))


# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints


class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass


class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)}. "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result
761
762        # ======================================================================
763        # define `load` func
764        # done locally since it depends on args to the decorator
765        # ======================================================================
766        # mypy thinks this isnt a classmethod
767        @classmethod  # type: ignore[misc]
768        def load(cls, data: dict[str, Any] | T) -> Type[T]:
769            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
770            if isinstance(data, cls):
771                return data
772
773            assert isinstance(
774                data, typing.Mapping
775            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
776
777            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
778
779            # initialize dict for keeping what we will pass to the constructor
780            ctor_kwargs: dict[str, Any] = dict()
781
782            # iterate over the fields of the class
783            for field in dataclasses.fields(cls):
784                # check if the field is a SerializableField
785                assert isinstance(
786                    field, SerializableField
787                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
788
789                # check if the field is in the data and if it should be initialized
790                if (field.name in data) and field.init:
791                    # get the value, we will be processing it
792                    value: Any = data[field.name]
793
794                    # get the type hint for the field
795                    field_type_hint: Any = cls_type_hints.get(field.name, None)
796
797                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
798                    if field.deserialize_fn:
799                        # if it has a deserialization function, use that
800                        value = field.deserialize_fn(value)
801                    elif field.loading_fn:
802                        # if it has a loading function, use that
803                        value = field.loading_fn(data)
804                    elif (
805                        field_type_hint is not None
806                        and hasattr(field_type_hint, "load")
807                        and callable(field_type_hint.load)
808                    ):
809                        # if no loading function but has a type hint with a load method, use that
810                        if isinstance(value, dict):
811                            value = field_type_hint.load(value)
812                        else:
813                            raise FieldLoadingError(
814                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
815                            )
816                    else:
817                        # assume no loading needs to happen, keep `value` as-is
818                        pass
819
820                    # store the value in the constructor kwargs
821                    ctor_kwargs[field.name] = value
822
823            # create a new instance of the class with the constructor kwargs
824            output: cls = cls(**ctor_kwargs)
825
826            # validate the types of the fields if needed
827            if on_typecheck_mismatch != ErrorMode.IGNORE:
828                fields_valid: dict[str, bool] = (
829                    SerializableDataclass__validate_fields_types__dict(
830                        output,
831                        on_typecheck_error=on_typecheck_error,
832                    )
833                )
834
835                # if there are any fields that are not valid, raise an error
836                if not all(fields_valid.values()):
837                    msg: str = (
838                        f"Type mismatch in fields of {cls.__name__}:\n"
839                        + "\n".join(
840                            [
841                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
842                                for k, v in fields_valid.items()
843                                if not v
844                            ]
845                        )
846                    )
847
848                    on_typecheck_mismatch.process(
849                        msg, except_cls=FieldTypeMismatchError
850                    )
851
852            # return the new instance
853            return output
854
855        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
856        # type is `Callable[[T], dict]`
857        cls.serialize = serialize  # type: ignore[attr-defined]
858        # type is `Callable[[dict], T]`
859        cls.load = load  # type: ignore[attr-defined]
860        # type is `Callable[[T, ErrorMode], bool]`
861        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
862
863        # type is `Callable[[T, T], bool]`
864        if not hasattr(cls, "__eq__"):
865            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
866
867        # Register the class with ZANJ
868        if register_handler:
869            zanj_register_loader_serializable_dataclass(cls)
870
871        return cls
872
873    if _cls is None:
874        return wrap
875    else:
876        return wrap(_cls)
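The tail of the decorator above (`if _cls is None: return wrap`, otherwise `return wrap(_cls)`) is the usual pattern that lets one decorator be used both bare and with keyword arguments. The following is a minimal, self-contained sketch of that pattern only; `tagged_dataclass` and its `tag` parameter are hypothetical names for illustration and are not part of muutils.

```python
import dataclasses
from typing import Callable, Optional, Type, TypeVar, Union

T = TypeVar("T")


def tagged_dataclass(
    _cls: Optional[Type[T]] = None, *, tag: str = "default"
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
    """hypothetical decorator usable as `@tagged_dataclass` or `@tagged_dataclass(tag=...)`"""

    def wrap(cls: Type[T]) -> Type[T]:
        # turn the class into a regular dataclass, then attach extra class-level state
        cls = dataclasses.dataclass(cls)
        cls._tag = tag  # type: ignore[attr-defined]
        return cls

    # called with arguments, e.g. `@tagged_dataclass(tag="x")`: return the decorator itself
    if _cls is None:
        return wrap
    # called bare, e.g. `@tagged_dataclass`: `_cls` is the class being decorated
    return wrap(_cls)


@tagged_dataclass
class A:
    x: int


@tagged_dataclass(tag="custom")
class B:
    y: str
```

Making `_cls` positional and every option keyword-only is what keeps the two call forms unambiguous.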

def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: tuple[typing.Union[type[typing.Any], typing.Callable[..., typing.Any]], ...] = (), **kwargs: Any) -> _IdentityCallable:
def dataclass_transform(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> _IdentityCallable:
    """Decorator to mark an object as providing dataclass-like behaviour.

    The decorator can be applied to a function, class, or metaclass.

    Example usage with a decorator function::

        @dataclass_transform()
        def create_model[T](cls: type[T]) -> type[T]:
            ...
            return cls

        @create_model
        class CustomerModel:
            id: int
            name: str

    On a base class::

        @dataclass_transform()
        class ModelBase: ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    On a metaclass::

        @dataclass_transform()
        class ModelMeta(type): ...

        class ModelBase(metaclass=ModelMeta): ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    The ``CustomerModel`` classes defined above will
    be treated by type checkers similarly to classes created with
    ``@dataclasses.dataclass``.
    For example, type checkers will assume these classes have
    ``__init__`` methods that accept ``id`` and ``name``.

    The arguments to this decorator can be used to customize this behavior:
    - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
        ``True`` or ``False`` if it is omitted by the caller.
    - ``order_default`` indicates whether the ``order`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``frozen_default`` indicates whether the ``frozen`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``field_specifiers`` specifies a static list of supported classes
        or functions that describe fields, similar to ``dataclasses.field()``.
    - Arbitrary other keyword arguments are accepted in order to allow for
        possible future extensions.

    At runtime, this decorator records its arguments in the
    ``__dataclass_transform__`` attribute on the decorated object.
    It has no other runtime effect.

    See PEP 681 for more details.
    """
    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn
    return decorator

Decorator to mark an object as providing dataclass-like behaviour.

The decorator can be applied to a function, class, or metaclass.

Example usage with a decorator function::

@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
    ...
    return cls

@create_model
class CustomerModel:
    id: int
    name: str

On a base class::

@dataclass_transform()
class ModelBase: ...

class CustomerModel(ModelBase):
    id: int
    name: str

On a metaclass::

@dataclass_transform()
class ModelMeta(type): ...

class ModelBase(metaclass=ModelMeta): ...

class CustomerModel(ModelBase):
    id: int
    name: str

The CustomerModel classes defined above will be treated by type checkers similarly to classes created with @dataclasses.dataclass. For example, type checkers will assume these classes have __init__ methods that accept id and name.

The arguments to this decorator can be used to customize this behavior:

  • eq_default indicates whether the eq parameter is assumed to be True or False if it is omitted by the caller.
  • order_default indicates whether the order parameter is assumed to be True or False if it is omitted by the caller.
  • kw_only_default indicates whether the kw_only parameter is assumed to be True or False if it is omitted by the caller.
  • frozen_default indicates whether the frozen parameter is assumed to be True or False if it is omitted by the caller.
  • field_specifiers specifies a static list of supported classes or functions that describe fields, similar to dataclasses.field().
  • Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.

At runtime, this decorator records its arguments in the __dataclass_transform__ attribute on the decorated object. It has no other runtime effect.

See PEP 681 for more details.
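The runtime side of this is easy to observe. The sketch below mirrors the `create_model` / `CustomerModel` example from the docstring, except that `create_model` actually delegates to `dataclasses.dataclass` so the class is usable at runtime. It assumes Python 3.11+ (where `typing.dataclass_transform` exists) or, on older versions, that `typing_extensions` is installed. Only keys present in every version are inspected, since `frozen_default` was added later than the others.

```python
import dataclasses
from typing import TypeVar

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:  # older Pythons: assumes the typing_extensions backport is installed
    from typing_extensions import dataclass_transform

T = TypeVar("T")


@dataclass_transform(field_specifiers=(dataclasses.field,))
def create_model(cls: type[T]) -> type[T]:
    # @dataclass_transform only informs type checkers; the runtime behaviour
    # is whatever we implement here, so we delegate to dataclasses.dataclass
    return dataclasses.dataclass(cls)


@create_model
class CustomerModel:
    id: int
    name: str = dataclasses.field(default="anonymous")


# the only runtime effect of @dataclass_transform: the recorded arguments
meta = create_model.__dataclass_transform__
```

`SerializableDataclass` uses the same mechanism with `field_specifiers=(serializable_field, SerializableField)`, so type checkers treat `serializable_field(...)` like `dataclasses.field(...)`.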

class CantGetTypeHintsWarning(builtins.UserWarning):
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass

special warning for when we can't get type hints

Inherited Members

  • builtins.UserWarning: UserWarning
  • builtins.BaseException: with_traceback, add_note, args
class ZanjMissingWarning(builtins.UserWarning):
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

Inherited Members

  • builtins.UserWarning: UserWarning
  • builtins.BaseException: with_traceback, add_note, args
def zanj_register_loader_serializable_dataclass(cls: Type[~T]):
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesn't matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts

TODO: there is some duplication here with register_loader_handler
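The registered loader decides whether a JSON item belongs to a class by checking that its `__format__` key starts with `f"{cls.__name__}(SerializableDataclass)"`. The sketch below reproduces just that predicate as a standalone function so it can be tried without ZANJ installed. `matches_serializable_dataclass` is a hypothetical helper name, not a muutils or ZANJ API.

```python
from typing import Any


def matches_serializable_dataclass(json_item: Any, cls_name: str) -> bool:
    """hypothetical helper mirroring the `check` lambda passed to `LoaderHandler` above"""
    _format: str = f"{cls_name}(SerializableDataclass)"
    return (
        isinstance(json_item, dict)
        and "__format__" in json_item
        and json_item["__format__"].startswith(_format)
    )


item = {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
```

Because `(SerializableDataclass)` is part of the prefix, a class named `My` does not claim items written by `MyClass`. Judging from this predicate alone, however, two classes with the same `__name__` in different modules would both match, since only the bare class name is compared.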

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass

Base class for warnings generated by user code.

Inherited Members

  • builtins.UserWarning: UserWarning
  • builtins.BaseException: with_traceback, add_note, args
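Because each of these warnings is its own `UserWarning` subclass, callers can filter them by category with the standard `warnings` machinery. The sketch below uses a local stand-in class with the same name as `FieldIsNotInitOrSerializeWarning` and a hypothetical `skip_if_not_checkable` function that emits it the same way `SerializableDataclass__validate_field_type` does for fields that are not both `init` and `serialize`. With muutils installed you would import the real class instead of defining it.

```python
import warnings


class FieldIsNotInitOrSerializeWarning(UserWarning):
    """local stand-in for the muutils warning class of the same name"""


def skip_if_not_checkable(name: str, init: bool, serialize: bool) -> bool:
    """hypothetical helper: returns True (and warns) if the field will not be type checked"""
    if not init or not serialize:
        warnings.warn(
            f"Field '{name}' is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True
    return False


# record every warning so we can inspect it
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    skip_if_not_checkable("cache", init=False, serialize=True)

# silence only this category; other UserWarnings would still be recorded
with warnings.catch_warnings(record=True) as silenced:
    warnings.simplefilter("always")
    warnings.simplefilter("ignore", FieldIsNotInitOrSerializeWarning)
    skip_if_not_checkable("cache", init=False, serialize=True)
```

`simplefilter` inserts at the front of the filter list, so the category-specific `"ignore"` added second takes precedence over the general `"always"`.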
def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

given a dataclass, check the field matches the type hint

this function implements SerializableDataclass.validate_field_type, which delegates to it

Parameters:

  • self : SerializableDataclass: the SerializableDataclass instance to check
  • field : SerializableField | str: the field to validate; if a str, it is looked up in self.__dataclass_fields__
  • on_typecheck_error : ErrorMode: what to do if type checking throws an exception (except, warn, ignore). With warn or ignore, the exception is not raised and the function returns False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

  • bool: True if the field's value matches its type hint, or if the field is skipped (assert_type is off, or the field is not both init and serialize). False if the type does not match, or if an exception is thrown and on_typecheck_error is warn or ignore
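The control flow above (look up the hint, fetch the value, check it, and route any exception through an error mode) can be sketched on plain dataclasses. This is a hypothetical simplification, not the muutils implementation: `validate_field` is an illustrative name, the strings `"except"`, `"warn"` and `"ignore"` stand in for `ErrorMode`, plain `isinstance` replaces muutils' `validate_type` (so parameterized generics such as `list[str]` raise instead of being checked), and `assert_type` / `custom_typecheck_fn` are not modelled.

```python
import dataclasses
import typing
import warnings
from typing import Any


def validate_field(obj: Any, field_name: str, on_error: str = "except") -> bool:
    """hypothetical sketch of per-field type validation

    `on_error` stands in for muutils' `ErrorMode`: "except", "warn", or "ignore".
    """
    try:
        hint = typing.get_type_hints(type(obj))[field_name]
    except KeyError as e:
        msg = f"no type hint for {field_name!r} on {type(obj).__name__}"
        if on_error == "except":
            raise TypeError(msg) from e
        if on_error == "warn":
            warnings.warn(msg)
        return False

    value = getattr(obj, field_name)
    try:
        # plain isinstance only understands concrete classes, not e.g. `list[str]`
        return isinstance(value, hint)
    except TypeError as e:
        msg = f"cannot check {field_name!r} against {hint!r}"
        if on_error == "except":
            raise ValueError(msg) from e
        if on_error == "warn":
            warnings.warn(msg)
        return False


@dataclasses.dataclass
class Point:
    x: int
    y: int
    tags: list[str]


p = Point(x=1, y="2", tags=["a"])  # dataclasses do not enforce annotations
```

A mismatch returns `False`, while an exception during checking is either raised or converted to `False`, matching the documented return behaviour.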
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict mapping each field name to a bool indicating whether that field's type is valid
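The aggregation step, which checks every field, records a failure as `False`, and reports all exceptions together chained to the first one, can be sketched independently of muutils. In the sketch below, `validate_all_fields`, `simple_check` and the `raise_on_error` flag are hypothetical names. The flag is a simplified stand-in for `on_typecheck_error`, and `simple_check` only handles fields annotated with plain classes.

```python
import dataclasses
from typing import Any, Callable


def validate_all_fields(
    obj: Any, check: Callable[[Any, str], bool], raise_on_error: bool = True
) -> dict[str, bool]:
    """hypothetical sketch: run `check` on every dataclass field and bundle failures

    a field whose check raises is recorded as False; the exceptions are then
    reported together, chained to the first one (a stand-in for ExceptionGroup,
    which is unavailable before Python 3.11)
    """
    results: dict[str, bool] = {}
    exceptions: dict[str, Exception] = {}
    for field in dataclasses.fields(obj):
        try:
            results[field.name] = check(obj, field.name)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    if exceptions and raise_on_error:
        details = "\n\t".join(f"{k}:\t{v}" for k, v in exceptions.items())
        raise ValueError(
            f"Exceptions while validating fields of {type(obj).__name__}:\n\t{details}"
        ) from next(iter(exceptions.values()))
    return results


def simple_check(obj: Any, name: str) -> bool:
    # only handles plain-class annotations like `int` or `str`;
    # isinstance raises TypeError for parameterized generics such as `list[int]`
    hint = {f.name: f.type for f in dataclasses.fields(obj)}[name]
    return isinstance(getattr(obj, name), hint)


@dataclasses.dataclass
class Config:
    lr: float
    name: str
    layers: list[int]
```

The overall `validate_fields_types` result is then simply `all(...)` over the values of this dict.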

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
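The recursion in `diff` can be illustrated with plain dataclasses. The following is a hypothetical simplification, not the muutils method. `dc_diff` is an illustrative name. It uses `!=` where muutils uses `array_safe_eq`, so array-valued fields would not work. It also recurses into any nested dataclass instead of raising for non-serializable ones, and it has no `of_serialized` option. The test reproduces the two examples from the docstring.

```python
import dataclasses
from typing import Any


def dc_diff(a: Any, b: Any) -> dict[str, Any]:
    """hypothetical sketch of a recursive field-by-field diff of two dataclass instances"""
    if type(a) is not type(b):
        raise ValueError(
            f"Instances must be of the same type, but got {type(a)} and {type(b)}"
        )
    result: dict[str, Any] = {}
    for field in dataclasses.fields(a):
        if not field.compare:  # like `diff`, skip fields excluded from comparison
            continue
        va, vb = getattr(a, field.name), getattr(b, field.name)
        if dataclasses.is_dataclass(va) and dataclasses.is_dataclass(vb):
            nested = dc_diff(va, vb)  # recurse; only keep non-empty sub-diffs
            if nested:
                result[field.name] = nested
        elif va != vb:
            result[field.name] = {"self": va, "other": vb}
    return result


@dataclasses.dataclass
class MyClass:
    a: int
    b: int


@dataclasses.dataclass
class NestedClass:
    x: str
    y: MyClass
```

Identical instances produce an empty dict, which is why the nested branch only stores a sub-diff when it is non-empty.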

Base class for serializable dataclasses

only for linting and type checking; you still need to apply the @serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
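The same nested and custom-field round trip can be reproduced with the standard library alone. The sketch below is not muutils' implementation. It assumes a hypothetical `custom_field` helper that stores the serialize/deserialize hooks in stdlib field `metadata`, standing in for `serializable_field(serialization_fn=..., deserialize_fn=...)`. A `datetime.date` plays the role of the non-JSON-serializable field.

```python
import dataclasses
import datetime
import json
import typing
from typing import Any


def custom_field(serialize, deserialize, **kwargs):
    "hypothetical helper: keeps per-field hooks in the stdlib field metadata"
    return dataclasses.field(
        metadata={"serialize": serialize, "deserialize": deserialize}, **kwargs
    )


@dataclasses.dataclass
class MyClass:
    a: int
    b: str


@dataclasses.dataclass
class Event:
    x: str
    y: MyClass
    # datetime.date is not JSON-serializable, so it is stored as an ISO string
    when: datetime.date = custom_field(
        serialize=lambda d: d.isoformat(),
        deserialize=datetime.date.fromisoformat,
        default=datetime.date(2024, 1, 1),
    )


def to_dict(obj) -> dict[str, Any]:
    out: dict[str, Any] = {"__format__": f"{type(obj).__name__}(SerializableDataclass)"}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            value = to_dict(value)  # nested dataclass: recurse
        elif "serialize" in f.metadata:
            value = f.metadata["serialize"](value)
        out[f.name] = value
    return out


def from_dict(cls, data: dict[str, Any]):
    hints = typing.get_type_hints(cls)
    kwargs: dict[str, Any] = {}
    # only field names are looked up, so extra keys such as "__format__" are ignored
    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue  # missing from data: the field keeps its default
        value = data[f.name]
        if "deserialize" in f.metadata:
            value = f.metadata["deserialize"](value)
        elif dataclasses.is_dataclass(hints[f.name]):
            value = from_dict(hints[f.name], value)  # nested dataclass: recurse
        kwargs[f.name] = value
    return cls(**kwargs)


ev = Event(x="q", y=MyClass(a=1, b="q"), when=datetime.date(2024, 5, 17))
restored = from_dict(Event, json.loads(json.dumps(to_dict(ev))))
assert restored == ev
```

The decorator in this module generates the equivalent of `to_dict`/`from_dict` per class, which is why each class gets its own `serialize` and `load`.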
def serialize(self) -> dict[str, typing.Any]:
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

returns the instance as a dict; implemented by the @serializable_dataclass decorator, and raises NotImplementedError if the class was not decorated
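The stub-plus-decorator pattern is easier to see in isolation. In this minimal sketch, `Base` and `make_serializable` are hypothetical stand-ins for `SerializableDataclass` and `@serializable_dataclass`. The base class only declares `serialize` for type checkers, and the decorator attaches the real implementation to the decorated class.

```python
import dataclasses
from typing import Any


class Base:
    "hypothetical stand-in for SerializableDataclass: the stub exists for type checkers"

    def serialize(self) -> dict[str, Any]:
        raise NotImplementedError(
            f"decorate {self.__class__.__name__} with `@make_serializable`"
        )


def make_serializable(cls):
    "hypothetical decorator: makes cls a dataclass and attaches a working `serialize`"
    cls = dataclasses.dataclass(cls)

    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            "__format__": f"{type(self).__name__}(SerializableDataclass)"
        }
        for f in dataclasses.fields(self):
            result[f.name] = getattr(self, f.name)
        return result

    cls.serialize = serialize  # shadows the stub inherited from Base
    return cls


class Undecorated(Base):
    a: int = 1


@make_serializable
class Decorated(Base):
    a: int = 1
```

Calling `Undecorated().serialize()` reaches the stub and raises, while `Decorated().serialize()` uses the attached method.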

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator, and raises NotImplementedError if the class was not decorated
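A rough sketch of what the generated `load` does for a class with plain fields is shown below. `Point` is a hypothetical class, and type validation and per-field loading functions are omitted. An existing instance is returned unchanged, which mirrors the passthrough in the decorator's `load`. Only init fields present in the data are passed to the constructor. Where the decorator uses an `assert` for non-mappings, this sketch raises `TypeError` instead.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any, Type, TypeVar, Union

T = TypeVar("T")


@dataclasses.dataclass
class Point:
    x: int
    y: int

    @classmethod
    def load(cls: Type[T], data: Union[Mapping[str, Any], T]) -> T:
        "hypothetical sketch of a generated `load`, without type validation"
        # already an instance: return it unchanged
        if isinstance(data, cls):
            return data
        if not isinstance(data, Mapping):
            raise TypeError(f"expected a Mapping, got {type(data).__name__}")
        # keys that are not init fields, such as "__format__", are never looked up
        kwargs = {
            f.name: data[f.name]
            for f in dataclasses.fields(cls)  # type: ignore[arg-type]
            if f.init and f.name in data
        }
        return cls(**kwargs)
```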

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

validate the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_field_type for each field
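The per-field idea can be sketched with `typing.get_type_hints` and `isinstance`. This is a deliberately simplified, hypothetical checker and not the one muutils uses. For a generic hint such as `list[int]`, it only checks the origin (`list`) and does not check the items. It does not handle `Any` or other special forms.

```python
import dataclasses
import typing


def validate_fields_types(obj) -> dict[str, bool]:
    "hypothetical sketch: map each field name to whether its value matches the hint"
    hints = typing.get_type_hints(type(obj))
    result: dict[str, bool] = {}
    for f in dataclasses.fields(obj):
        hint = hints[f.name]
        # for list[int] only `list` is checked, not the element types
        check = typing.get_origin(hint) or hint
        result[f.name] = isinstance(getattr(obj, f.name), check)
    return result


@dataclasses.dataclass
class Config:
    name: str
    sizes: list[int]


ok = Config(name="a", sizes=[1, 2])
bad = Config(name=3, sizes=[1, 2])  # plain dataclasses do not check types at construction
```

Reducing the dict with `all(...)` gives the single boolean that `validate_fields_types` returns.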

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

check that the value of the given field (a SerializableField or a field name) matches its type hint
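The `on_typecheck_error` argument controls what happens when the check itself throws, rather than when it simply returns a mismatch. This sketch uses a hypothetical `Mode` enum in place of `muutils.errormode.ErrorMode`, and it uses a bare `isinstance` as the checker. A parameterized generic such as `list[int]` makes `isinstance` raise `TypeError`, which triggers the three behaviors:

- except: re-raise the error
- warn: emit a warning and return `False`
- ignore: return `False` silently

```python
import dataclasses
import enum
import typing
import warnings


class Mode(enum.Enum):
    "hypothetical stand-in for muutils.errormode.ErrorMode"
    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"


def validate_field_type(obj, field_name: str, on_error: Mode = Mode.EXCEPT) -> bool:
    hint = typing.get_type_hints(type(obj))[field_name]
    try:
        return isinstance(getattr(obj, field_name), hint)
    except TypeError as e:
        # the check failed, e.g. isinstance() rejects parameterized generics like list[int]
        if on_error is Mode.EXCEPT:
            raise
        if on_error is Mode.WARN:
            warnings.warn(f"could not check {field_name!r} against {hint}: {e}")
        return False  # WARN and IGNORE both report the field as not valid


@dataclasses.dataclass
class Sample:
    n: int
    xs: list[int]
```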

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if a field value is a `dataclasses.dataclass` but not a `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if a field value is a dataclasses.dataclass but not a SerializableDataclass
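The recursion in `diff` can be sketched over plain stdlib dataclasses. This hypothetical `diff` function differs from the method above in two ways. It recurses into any dataclass instead of raising for non-serializable ones, and it compares with `!=` rather than `array_safe_eq`, so it is not suitable for array-valued fields. It does keep the same output shape and skips fields declared with `compare=False`.

```python
import dataclasses
from typing import Any


def diff(a, b) -> dict[str, Any]:
    "hypothetical sketch of a recursive field-by-field diff"
    if type(a) is not type(b):
        raise ValueError(f"types differ: {type(a)} vs {type(b)}")
    result: dict[str, Any] = {}
    for f in dataclasses.fields(a):
        if not f.compare:  # respect field(compare=False)
            continue
        va, vb = getattr(a, f.name), getattr(b, f.name)
        if dataclasses.is_dataclass(va) and dataclasses.is_dataclass(vb):
            nested = diff(va, vb)
            if nested:  # only record nested fields that actually differ
                result[f.name] = nested
        elif va != vb:
            result[f.name] = {"self": va, "other": vb}
    return result


@dataclasses.dataclass
class MyClass:
    a: int
    b: int


@dataclasses.dataclass
class NestedClass:
    x: str
    y: MyClass
```

With these classes, the two calls from the docstring above produce the same dicts.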
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

  • nested_dict : dict[str, Any] nested dict to update the instance with
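The command-line use case can be sketched as follows. The stdlib config classes are hypothetical. The `dotted_to_nested` helper is also hypothetical and not part of muutils: it turns arguments like `optim.lr=0.01` into the nested dict that the update step consumes. Values are parsed as JSON literals, which is an assumption of this sketch. Unlike the method above, the sketch only recurses when the incoming value is itself a dict.

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass
class OptimConfig:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class TrainConfig:
    epochs: int = 10
    optim: OptimConfig = dataclasses.field(default_factory=OptimConfig)


def update_from_nested_dict(obj, nested: dict[str, Any]) -> None:
    "hypothetical sketch: recurse into nested dataclasses, overwrite everything else"
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue  # fields not mentioned keep their current value
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current) and isinstance(nested[f.name], dict):
            update_from_nested_dict(current, nested[f.name])
        else:
            setattr(obj, f.name, nested[f.name])


def dotted_to_nested(pairs: list[str]) -> dict[str, Any]:
    "hypothetical helper: ['optim.lr=0.01'] -> {'optim': {'lr': 0.01}}"
    nested: dict[str, Any] = {}
    for pair in pairs:
        key, raw = pair.split("=", 1)
        *parents, leaf = key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = json.loads(raw)  # values are parsed as JSON literals
    return nested


cfg = TrainConfig()
update_from_nested_dict(cfg, dotted_to_nested(["epochs=20", "optim.lr=0.01"]))
```

Only the mentioned leaves change. `cfg.optim.momentum` keeps its default because the nested `OptimConfig` is updated in place rather than replaced.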
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]:
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)

cached typing.get_type_hints for a class
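Caching works here because classes are hashable, so `functools.lru_cache` can key on the class itself. This minimal sketch uses a hypothetical function name and class. It shows that repeated calls return the very same dict object. As a consequence, callers should treat the result as read-only, since mutating it would silently change the cached value.

```python
import functools
import typing


@functools.lru_cache(typed=True)
def get_type_hints_cached(cls: type) -> dict[str, typing.Any]:
    "hypothetical re-creation: cache typing.get_type_hints per class"
    return typing.get_type_hints(cls)


class A:
    x: int
    y: str
```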

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]:
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints

helper function to get type hints for a class
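The helper above funnels three failure modes into a single `TypeError`: an unresolvable forward reference (`NameError`), an invalid hint (`TypeError`), and a class with no hints at all (`ValueError`). This is a reduced, hypothetical sketch without the cache or the `dataclasses.fields` detail. It keeps the original exception reachable through `__cause__`.

```python
import sys
import typing
from typing import Any


def get_cls_type_hints(cls: type) -> dict[str, Any]:
    "hypothetical sketch: surface every hint-resolution failure as one TypeError"
    try:
        hints = typing.get_type_hints(cls)
        if not hints:
            raise ValueError(f"empty type hints for {cls.__name__}")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls.__name__} "
            f"(python {sys.version_info.major}.{sys.version_info.minor}): {e!r}"
        ) from e
    return hints


class Good:
    a: int


class Broken:
    b: "UndefinedName"  # noqa: F821  forward reference that never resolves


class Empty:
    pass
```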

class KWOnlyError(builtins.NotImplementedError):
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass

kw-only dataclasses are not supported in python <3.10

Inherited Members
  • builtins.NotImplementedError: NotImplementedError
  • builtins.BaseException: with_traceback, add_note, args
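The `kw_only` argument to `dataclasses.dataclass` only exists on python 3.10 and later. The decorator therefore guards it with a version check before forwarding its kwargs. This is a hypothetical sketch of that guard in isolation. `KWOnlyError` is re-created locally so the block runs on its own. As in the source, `kw_only=False` is dropped on older versions, which is safe because it is the default.

```python
import dataclasses
import sys


class KWOnlyError(NotImplementedError):
    "re-created here so the sketch runs on its own"


def dataclass_compat(cls, **kwargs):
    "hypothetical sketch of the kw_only version guard"
    # dataclasses.dataclass() only accepts kw_only on python >= 3.10
    if sys.version_info < (3, 10) and "kw_only" in kwargs:
        if kwargs["kw_only"]:
            raise KWOnlyError("kw_only is not supported in python < 3.10")
        del kwargs["kw_only"]  # False is the default anyway, so dropping it is safe
    return dataclasses.dataclass(cls, **kwargs)


class Point:
    x: int
    y: int


Point = dataclass_compat(Point, kw_only=False)  # works on every supported python
```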
class FieldError(builtins.ValueError):
class FieldError(ValueError):
    "base class for field errors"

    pass

base class for field errors

Inherited Members
  • builtins.ValueError: ValueError
  • builtins.BaseException: with_traceback, add_note, args
class NotSerializableFieldException(FieldError):
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass

field is not a SerializableField

Inherited Members
  • builtins.ValueError: ValueError
  • builtins.BaseException: with_traceback, add_note, args
class FieldSerializationError(FieldError):
class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass

error while serializing a field

Inherited Members
  • builtins.ValueError: ValueError
  • builtins.BaseException: with_traceback, add_note, args
class FieldLoadingError(FieldError):
class FieldLoadingError(FieldError):
    "error while loading a field"

    pass

error while loading a field

Inherited Members
  • builtins.ValueError: ValueError
  • builtins.BaseException: with_traceback, add_note, args
class FieldTypeMismatchError(FieldError, builtins.TypeError):
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass

error when a field type does not match the type hint

Inherited Members
  • builtins.ValueError: ValueError
  • builtins.BaseException: with_traceback, add_note, args
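Because `FieldTypeMismatchError` inherits from both `FieldError` (a `ValueError`) and `TypeError`, callers can catch it under whichever category they already handle. The sketch below re-creates just those two classes. `check_port` is a hypothetical validator used only to trigger the error.

```python
class FieldError(ValueError):
    "base class for field errors"


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"


def check_port(raw: object) -> int:
    "hypothetical validator used only to trigger the error"
    if not isinstance(raw, int):
        raise FieldTypeMismatchError(f"expected int, got {type(raw).__name__}")
    return raw


# the same exception is caught whichever base class the caller guards against
for handler in (FieldError, ValueError, TypeError):
    try:
        check_port("8080")
    except handler as e:
        print(f"caught as {handler.__name__}: {e}")
```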
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool (defaults to True)
  • repr : bool (defaults to True)
  • eq : bool (defaults to True)
  • order : bool (defaults to False)
  • unsafe_hash : bool (defaults to False)
  • frozen : bool (defaults to False)
  • properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
  • register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
  • on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
  • on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True
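The `_cls=None` plus keyword-only signature is what lets the decorator be used both bare (`@serializable_dataclass`) and called (`@serializable_dataclass(...)`). This is a hypothetical, much-reduced sketch of that pattern. The `serializable` function is an illustration and is not this module's API. It also shows the effect of `properties_to_serialize`: computed properties are written into the output dict alongside the fields. Any other keyword arguments, such as `frozen`, are forwarded to `dataclasses.dataclass`.

```python
import dataclasses
from typing import Any, Optional


def serializable(_cls=None, *, properties_to_serialize: Optional[list[str]] = None, **dataclass_kwargs):
    "hypothetical sketch of a decorator usable both bare and with arguments"
    props = list(properties_to_serialize or [])

    def wrap(cls):
        cls = dataclasses.dataclass(cls, **dataclass_kwargs)

        def serialize(self) -> dict[str, Any]:
            out: dict[str, Any] = {"__format__": f"{type(self).__name__}(SerializableDataclass)"}
            out.update({f.name: getattr(self, f.name) for f in dataclasses.fields(self)})
            for prop in props:
                out[prop] = getattr(self, prop)  # computed value, written but not needed to load
            return out

        cls.serialize = serialize
        return cls

    # bare `@serializable` passes the class as `_cls`; `@serializable(...)` does not
    return wrap if _cls is None else wrap(_cls)


@serializable
class Plain:
    a: int


@serializable(properties_to_serialize=["area"], frozen=True)
class Rect:
    w: int
    h: int

    @property
    def area(self) -> int:
        return self.w * self.h
```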

Returns:

  • _type_ the decorated class

Raises: