docs for muutils v0.8.0
View Source on GitHub

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True

  1"""save and load objects to and from json or compatible formats in a recoverable way
  2
  3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
  4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
  5
  6Instead, you define your class:
  7
  8```python
  9@serializable_dataclass
 10class MyClass(SerializableDataclass):
 11    a: int
 12    b: str
 13```
 14
 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
 16
 17    >>> my_obj = MyClass(a=1, b="q")
 18    >>> s = json.dumps(my_obj.serialize())
 19    >>> s
 20    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
 21    >>> read_obj = MyClass.load(json.loads(s))
 22    >>> read_obj == my_obj
 23    True
 24
 25This isn't too impressive on its own, but it gets more useful when you have nested classes,
 26or fields that are not json-serializable by default:
 27
 28```python
 29@serializable_dataclass
 30class NestedClass(SerializableDataclass):
 31    x: str
 32    y: MyClass
 33    act_fun: torch.nn.Module = serializable_field(
 34        default=torch.nn.ReLU(),
 35        serialization_fn=lambda x: str(x),
 36        deserialize_fn=lambda x: getattr(torch.nn, x)(),
 37    )
 38```
 39
 40which gives us:
 41
 42    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
 43    >>> s = json.dumps(nc.serialize())
 44    >>> s
 45    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
 46    >>> read_nc = NestedClass.load(json.loads(s))
 47    >>> read_nc == nc
 48    True
 49
 50"""
 51
 52from __future__ import annotations
 53
 54import abc
 55import dataclasses
 56import functools
 57import json
 58import sys
 59import typing
 60import warnings
 61from typing import Any, Optional, Type, TypeVar
 62
 63from muutils.errormode import ErrorMode
 64from muutils.validate_type import validate_type
 65from muutils.json_serialize.serializable_field import (
 66    SerializableField,
 67    serializable_field,
 68)
 69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq
 70
 71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
 72
 73
 74def _dataclass_transform_mock(
 75    *,
 76    eq_default: bool = True,
 77    order_default: bool = False,
 78    kw_only_default: bool = False,
 79    frozen_default: bool = False,
 80    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
 81    **kwargs: Any,
 82) -> typing.Callable:
 83    "mock `typing.dataclass_transform` for python <3.11"
 84
 85    def decorator(cls_or_fn):
 86        cls_or_fn.__dataclass_transform__ = {
 87            "eq_default": eq_default,
 88            "order_default": order_default,
 89            "kw_only_default": kw_only_default,
 90            "frozen_default": frozen_default,
 91            "field_specifiers": field_specifiers,
 92            "kwargs": kwargs,
 93        }
 94        return cls_or_fn
 95
 96    return decorator
 97
 98
# pick the `dataclass_transform` decorator used below:
# the local mock on python <3.11, `typing.dataclass_transform` otherwise
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


# generic type variable used in the `Type[T]` annotations in this module
T = TypeVar("T")
106
107
class CantGetTypeHintsWarning(UserWarning):
    """warning category for the case where type hints cannot be retrieved"""
112
113
class ZanjMissingWarning(UserWarning):
    """warning category for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing

    without ZANJ, `zanj_register_loader_serializable_dataclass` cannot register a loader
    """
118
119
# module-level flag used by `zanj_register_loader_serializable_dataclass`; starts out `True`
# (the string on the next line is a bare expression statement documenting the flag)
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"
122
123
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register. `cls.__name__` is used to build the format string to match,
       `cls.__module__` is recorded as the source package, and `cls.load` is used to load matching items

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if `zanj.loading` could not be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # the import is done on every call instead of being gated on `_zanj_loading_needs_import`:
    # the imported names are local to this function, so skipping the import would leave
    # `LoaderHandler` and `register_loader_handler` unbound. repeated imports are cheap
    # because python caches imported modules.
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
        # this is deliberately silent (no `ZanjMissingWarning` is emitted)
        return None

    # record that ZANJ was imported successfully
    _zanj_loading_needs_import = False

    # serialized items are recognized by a `_FORMAT_KEY` entry starting with this string
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
164
165
# default error modes for type checking, used as parameter defaults in this module
# default for `on_typecheck_mismatch` (a type mismatch was found): `WARN`
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default for `on_typecheck_error` (type checking itself threw an exception): `EXCEPT`
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
168
169
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning category for a field that is skipped by type validation because it is not `init` or not `serialize`"""
172
173
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of one field on a dataclass instance matches its type hint

    this function is exposed on instances as `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field itself, or its name (looked up in `self.__dataclass_fields__`)
     - `on_typecheck_error : ErrorMode`
       how to handle an exception thrown during type checking (except, warn, ignore).
       if the exception is not re-raised, the function returns `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the type is correct, or if the field is skipped
       (`assert_type` is off, or the field is not `init` / not `serialize`).
       `False` if the type is wrong, or if checking failed and the error mode did not raise
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name to the actual field object
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # the field opted out of type checking
    if not _field.assert_type:
        return True

    # fields which are not `init` or not `serialize` are skipped, with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint for this field
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as err:
        # hints were found for the class, but not for this field
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                "  - use python 3.9.x or higher\n"
                "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=err,
        )
        return False

    # current value of the field on this instance
    value: Any = getattr(self, _field.name)

    # run the default validator, or the field's custom one if it has one
    try:
        custom_checker = _field.custom_typecheck_fn
        if custom_checker is None:
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom checker is called with the type hint, not the value -- confirm this is intended
        return custom_checker(field_type_hint)
    except Exception as err:
        on_typecheck_error.process(
            "exception while validating type: "
            f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=err,
        )
        return False
264
265
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """run `self.validate_field_type` on every field of a `SerializableDataclass`

    returns a dict mapping each field name to whether its type is valid.
    a field whose validation raised is recorded as `False`; all such exceptions
    are reported together through `on_typecheck_error` once every field was visited
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    failures: dict[str, Exception] = {}
    for fld in all_fields:
        try:
            is_valid = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            # remember the exception and keep going, so every field gets checked
            is_valid = False
            failures[fld.name] = err
        validity[fld.name] = is_valid

    # report the collected exceptions in one go
    if failures:
        failure_lines = "\n\t".join(f"{k}:\t{v}" for k, v in failures.items())
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in all_fields]}"
            + "\n\t"
            + failure_lines,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity
301
302
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check the types of all fields on a `SerializableDataclass`

    `True` only if every entry returned by
    `SerializableDataclass__validate_fields_types__dict` is truthy
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
313
314
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder: raises until the class is decorated
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder: raises until the class is decorated
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint

        `field` may be the field object or its name.
        delegates to `SerializableDataclass__validate_field_type`
        """
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values.
           a field that is absent from the serialized form of both instances is skipped;
           a field present in only one of them is reported as a difference
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances are equal. otherwise maps each differing field name to
           `{"self": ..., "other": ...}` (always the raw attribute values), or to a nested
           diff when both values are `SerializableDataclass` instances

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            # best effort only: if the comparison itself raises,
            # fall through to the field-by-field comparison below
            pass

        # if we are working with serialized data, serialize the instances
        # (always bound, so the names are safe to reference below)
        ser_self: dict = {}
        ser_other: dict = {}
        if of_serialized:
            ser_self = self.serialize()
            ser_other = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            elif of_serialized:
                # a field can be missing from the serialized dicts,
                # so check membership instead of indexing blindly
                in_self: bool = field_name in ser_self
                in_other: bool = field_name in ser_other
                if not (in_self or in_other):
                    # not part of the serialized form at all, nothing to compare
                    continue
                if in_self != in_other or not array_safe_eq(
                    ser_self[field_name], ser_other[field_name]
                ):
                    diff_result[field_name] = {"self": self_value, "other": other_value}
            else:
                # compare the raw values
                if not array_safe_eq(self_value, other_value):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten via `setattr`.
        keys of `nested_dict` which are not field names are ignored

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "copy by serializing to json and loading again -- same as `__deepcopy__`, so this is not a shallow copy"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json. `memo` is not used"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
512
513
# memoized per class, so the hints are not recomputed on every lookup
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(maxsize=128, typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "`typing.get_type_hints` for a class, with the result cached"
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints
520
521
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first and retries with `typing.get_type_hints` directly
    if the cached result is empty

    # Parameters:
     - `cls : Type[T]`
       class to get the type hints of

    # Returns:
     - `dict[str, Any]`
       non-empty mapping of attribute names to type hints

    # Raises:
     - `TypeError` : if the type hints cannot be resolved (`TypeError` or `NameError` during lookup) or are empty
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` raises `TypeError` when `cls` is not a dataclass;
        # guard it so that failure does not replace the diagnostic message built below
        try:
            fields_desc: str = repr(dataclasses.fields(cls))  # type: ignore[arg-type]
        except TypeError:
            fields_desc = "<unavailable: not a dataclass>"

        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  dataclasses.fields(cls) = {fields_desc}\n"
            + f"  {e = }"
        ) from e

    return cls_type_hints
541
542
class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10

    raised by `serializable_dataclass` when `kw_only=True` is passed on such a version
    """
547
548
class FieldError(ValueError):
    """common base class for the field-related errors defined below"""
553
554
class NotSerializableFieldException(FieldError):
    """raised when a field is not a `SerializableField`"""
559
560
class FieldSerializationError(FieldError):
    """raised when serializing a field fails"""
565
566
class FieldLoadingError(FieldError):
    """raised when loading a field fails"""
571
572
class FieldTypeMismatchError(FieldError, TypeError):
    """raised when the type of a field does not match its type hint

    both a `FieldError` and a `TypeError`, so it can be caught as either
    """
577
578
579@dataclass_transform(
580    field_specifiers=(serializable_field, SerializableField),
581)
582def serializable_dataclass(
583    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
584    _cls=None,  # type: ignore
585    *,
586    init: bool = True,
587    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
588    eq: bool = True,
589    order: bool = False,
590    unsafe_hash: bool = False,
591    frozen: bool = False,
592    properties_to_serialize: Optional[list[str]] = None,
593    register_handler: bool = True,
594    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
595    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
596    methods_no_override: list[str] | None = None,
597    **kwargs,
598):
599    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
600
601    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
602
603    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
604
605    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
606
607    Examines PEP 526 `__annotations__` to determine fields.
608
609    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
610
611    ```python
612    @serializable_dataclass(kw_only=True)
613    class Myclass(SerializableDataclass):
614        a: int
615        b: str
616    ```
617    ```python
618    >>> Myclass(a=1, b="q").serialize()
619    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
620    ```
621
622    # Parameters:
623
624    - `_cls : _type_`
625       class to decorate. don't pass this arg, just use this as a decorator
626       (defaults to `None`)
627    - `init : bool`
628       whether to add an `__init__` method
629       *(passed to dataclasses.dataclass)*
630       (defaults to `True`)
631    - `repr : bool`
632       whether to add a `__repr__` method
633       *(passed to dataclasses.dataclass)*
634       (defaults to `True`)
635    - `order : bool`
636       whether to add rich comparison methods
637       *(passed to dataclasses.dataclass)*
638       (defaults to `False`)
639    - `unsafe_hash : bool`
640       whether to add a `__hash__` method
641       *(passed to dataclasses.dataclass)*
642       (defaults to `False`)
643    - `frozen : bool`
644       whether to make the class frozen
645       *(passed to dataclasses.dataclass)*
646       (defaults to `False`)
647    - `properties_to_serialize : Optional[list[str]]`
648       which properties to add to the serialized data dict
649       **SerializableDataclass only**
650       (defaults to `None`)
651    - `register_handler : bool`
652        if true, register the class with ZANJ for loading
653        **SerializableDataclass only**
654        (defaults to `True`)
655    - `on_typecheck_error : ErrorMode`
656        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
657        **SerializableDataclass only**
658    - `on_typecheck_mismatch : ErrorMode`
659        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
660        **SerializableDataclass only**
661    - `methods_no_override : list[str]|None`
662        list of methods that should not be overridden by the decorator
663        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
664        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
665        **SerializableDataclass only**
666        (defaults to `None`)
667    - `**kwargs`
668        *(passed to dataclasses.dataclass)*
669
670    # Returns:
671
672    - `_type_`
673       the decorated class
674
675    # Raises:
676
677    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
678    - `NotSerializableFieldException` : if a field is not a `SerializableField`
679    - `FieldSerializationError` : if there is an error serializing a field
680    - `AttributeError` : if a property is not found on the class
681    - `FieldLoadingError` : if there is an error loading a field
682    """
683    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
684    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
685    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
686
687    if properties_to_serialize is None:
688        _properties_to_serialize: list = list()
689    else:
690        _properties_to_serialize = properties_to_serialize
691
692    def wrap(cls: Type[T]) -> Type[T]:
693        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
694        for field_name, field_type in cls.__annotations__.items():
695            field_value = getattr(cls, field_name, None)
696            if not isinstance(field_value, SerializableField):
697                if isinstance(field_value, dataclasses.Field):
698                    # Convert the field to a SerializableField while preserving properties
699                    field_value = SerializableField.from_Field(field_value)
700                else:
701                    # Create a new SerializableField
702                    field_value = serializable_field()
703                setattr(cls, field_name, field_value)
704
705        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
706        if sys.version_info < (3, 10):
707            if "kw_only" in kwargs:
708                if kwargs["kw_only"] == True:  # noqa: E712
709                    raise KWOnlyError(
710                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
711                    )
712                else:
713                    del kwargs["kw_only"]
714
715        # call `dataclasses.dataclass` to set some stuff up
716        cls = dataclasses.dataclass(  # type: ignore[call-overload]
717            cls,
718            init=init,
719            repr=repr,
720            eq=eq,
721            order=order,
722            unsafe_hash=unsafe_hash,
723            frozen=frozen,
724            **kwargs,
725        )
726
727        # copy these to the class
728        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
729
730        # ======================================================================
731        # define `serialize` func
732        # done locally since it depends on args to the decorator
733        # ======================================================================
734        def serialize(self) -> dict[str, Any]:
735            result: dict[str, Any] = {
736                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
737            }
738            # for each field in the class
739            for field in dataclasses.fields(self):  # type: ignore[arg-type]
740                # need it to be our special SerializableField
741                if not isinstance(field, SerializableField):
742                    raise NotSerializableFieldException(
743                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
744                        f"but a {type(field)} "
745                        "this state should be inaccessible, please report this bug!"
746                    )
747
748                # try to save it
749                if field.serialize:
750                    try:
751                        # get the val
752                        value = getattr(self, field.name)
753                        # if it is a serializable dataclass, serialize it
754                        if isinstance(value, SerializableDataclass):
755                            value = value.serialize()
756                        # if the value has a serialization function, use that
757                        if hasattr(value, "serialize") and callable(value.serialize):
758                            value = value.serialize()
759                        # if the field has a serialization function, use that
760                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
761                        elif field.serialization_fn:
762                            value = field.serialization_fn(value)
763
764                        # store the value in the result
765                        result[field.name] = value
766                    except Exception as e:
767                        raise FieldSerializationError(
768                            "\n".join(
769                                [
770                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
771                                    f"{field = }",
772                                    f"{value = }",
773                                    f"{self = }",
774                                ]
775                            )
776                        ) from e
777
778            # store each property if we can get it
779            for prop in self._properties_to_serialize:
780                if hasattr(cls, prop):
781                    value = getattr(self, prop)
782                    result[prop] = value
783                else:
784                    raise AttributeError(
785                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
786                        + f"but it is in {self._properties_to_serialize = }"
787                        + f"\n{self = }"
788                    )
789
790            return result
791
792        # ======================================================================
793        # define `load` func
794        # done locally since it depends on args to the decorator
795        # ======================================================================
796        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            """build an instance of `cls` from a mapping of field names to stored values

            if `data` is already an instance of `cls`, it is returned unchanged.
            otherwise, every dataclass field that is present in `data` and has `init` set
            is converted (see the branch order below) and passed to the constructor.
            fields missing from `data` are simply not passed, so the constructor decides what happens to them.

            `on_typecheck_mismatch` and `on_typecheck_error` are read from the enclosing decorator scope.

            # Parameters:
             - `data : dict[str, Any] | T`
                mapping to load from, or an existing instance of `cls`

            # Returns:
             - `T`
                the existing instance, or a newly constructed one

            # Raises:
             - `AssertionError` : if `data` is not a `typing.Mapping`, or a field is not a `SerializableField`
             - `FieldLoadingError` : if a field's type hint has a callable `load` but the stored value is not a `dict`
             - `FieldTypeMismatchError` : passed as `except_cls` to `on_typecheck_mismatch.process` when validation reports invalid fields
            """
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        # (it receives only this field's stored value)
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # (note: it receives the whole `data` mapping, not just this field's value)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if any field is invalid, build a per-field report and hand it to `on_typecheck_mismatch`
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output
884
885        _methods_no_override: set[str]
886        if methods_no_override is None:
887            _methods_no_override = set()
888        else:
889            _methods_no_override = set(methods_no_override)
890
891        if _methods_no_override - {
892            "__eq__",
893            "serialize",
894            "load",
895            "validate_fields_types",
896        }:
897            warnings.warn(
898                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
899            )
900
901        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
902        if "serialize" not in _methods_no_override:
903            # type is `Callable[[T], dict]`
904            cls.serialize = serialize  # type: ignore[attr-defined]
905        if "load" not in _methods_no_override:
906            # type is `Callable[[dict], T]`
907            cls.load = load  # type: ignore[attr-defined]
908
909        if "validate_field_type" not in _methods_no_override:
910            # type is `Callable[[T, ErrorMode], bool]`
911            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
912
913        if "__eq__" not in _methods_no_override:
914            # type is `Callable[[T, T], bool]`
915            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
916
917        # Register the class with ZANJ
918        if register_handler:
919            zanj_register_loader_serializable_dataclass(cls)
920
921        return cls
922
923    if _cls is None:
924        return wrap
925    else:
926        return wrap(_cls)

class CantGetTypeHintsWarning(builtins.UserWarning):
class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

special warning for when we can't get type hints

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
class ZanjMissingWarning(builtins.UserWarning):
class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def zanj_register_loader_serializable_dataclass(cls: Type[~T]):
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a loader for `cls` with ZANJ, when ZANJ can be imported

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    returns the `LoaderHandler` that was registered, or `None` if importing `zanj.loading` fails

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
            # so this is deliberately silent (no `ZanjMissingWarning` is emitted here)
            return

    # tag that `serialize` stores under `_FORMAT_KEY` for this class
    format_tag: str = f"{cls.__name__}(SerializableDataclass)"

    handler: LoaderHandler = LoaderHandler(
        # an item matches when it is a dict whose `_FORMAT_KEY` entry starts with our tag
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(format_tag)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=format_tag,
        source_pckg=cls.__module__,
        desc=f"{format_tag} loader via muutils.json_serialize.serializable_dataclass",
    )
    register_loader_handler(handler)

    return handler

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts

TODO: there is some duplication here with register_loader_handler

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning issued when a field is skipped during type validation because it is not `init` or not `serialize`"""

Base class for warnings generated by user code.

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    `SerializableDataclass.validate_field_type` delegates to this function

    fields with a falsy `assert_type` are reported as valid without any check.
    fields that are not `init` or not `serialize` are also reported as valid,
    after emitting a `FieldIsNotInitOrSerializeWarning`.

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect,
        or if a problem was handed to `on_typecheck_error.process` and that call returned instead of raising

    # Raises:
     - whatever `on_typecheck_error.process` raises; it is called with `except_cls=TypeError`
       when no type hint is found for the field, and with `except_cls=ValueError` when the check itself throws
    """
    # normalize whatever was passed into an `ErrorMode`
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    # NOTE(review): this check runs after `_field.assert_type` / `.init` / `.serialize` were already read above,
    # so a plain `dataclasses.Field` would most likely fail earlier with an AttributeError -- confirm the ordering is intended
    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        # no hint recorded under this field's name: report it, and treat the field as invalid if `process` returns
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        # NOTE(review): the custom function is called with the type hint only, not with `value` -- confirm this is the intended contract
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        # the validator itself threw: report it, and treat the field as invalid if `process` returns
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

given a dataclass, check the field matches the type hint

this function is written to SerializableDataclass.validate_field_type

Parameters:

  • self : SerializableDataclass — the SerializableDataclass instance to check
  • field : SerializableField | str — field to validate; if a str, it is looked up in self.__dataclass_fields__
  • on_typecheck_error : ErrorMode — what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

  • bool — True if the field type is correct; False if the field type is incorrect, or if an exception is thrown and on_typecheck_error is ignore
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """check every field of a `SerializableDataclass` against its type hint

    calls `self.validate_field_type` once per dataclass field. a field whose check raises
    is recorded as `False`, and all such exceptions are reported together at the end
    through `on_typecheck_error.process`

    # Returns:
     - `dict[str, bool]`
        field name -> whether that field's type is valid
    """
    mode = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    failures: dict[str, Exception] = {}

    for fld in all_fields:
        try:
            is_ok = self.validate_field_type(fld, mode)
        except Exception as err:
            # remember the failure, keep going so that every field gets checked
            failures[fld.name] = err
            is_ok = False
        validity[fld.name] = is_ok

    # report all collected exceptions in one go
    if failures:
        field_names = [f.name for f in all_fields]
        details = "\n\t".join(
            [f"{name}:\t{exc}" for name, exc in failures.items()]
        )
        mode.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {field_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first recorded exception only
            except_from=next(iter(failures.values())),
        )

    return validity

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """return `True` only if every field passes type validation

    thin wrapper over `SerializableDataclass__validate_fields_types__dict`, which checks each field in turn
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    calling `my_obj.serialize()` then gives a dict that can be serialized to json:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This becomes more useful with nested classes,
    or with fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator is expected to install the real implementation
        message = f"decorate {self.__class__ = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator is expected to install the real implementation
        message = f"decorate {cls = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """check all fields against their type hints, see `SerializableDataclass__validate_fields_types`"""
        all_valid = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, see `SerializableDataclass__validate_field_type`"""
        field_valid = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return field_valid

    def __eq__(self, other: Any) -> bool:
        # equality is delegated to `dc_eq`
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal; otherwise one entry per differing field.
           nested serializable dataclasses contribute their own nested diff

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        differences: dict = {}

        # shortcut for equal instances; if the comparison itself blows up,
        # fall through to the field-by-field comparison below
        try:
            if self == other:
                return differences
        except Exception:
            pass

        # serialized views are only needed (and only bound) in `of_serialized` mode
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison never show up in the diff
            if not field.compare:
                continue

            name: str = field.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # nested serializable dataclasses: recurse, keep only non-empty diffs
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                sub_diff: dict = mine.diff(theirs, of_serialized=of_serialized)
                if sub_diff:
                    differences[name] = sub_diff
                continue

            # plain dataclasses are not supported
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what gets compared: serialized entries or the raw attribute values
            if of_serialized:
                cmp_mine, cmp_theirs = ser_self[name], ser_other[name]
            else:
                cmp_mine, cmp_theirs = mine, theirs

            # the diff always records the raw attribute values
            if not array_safe_eq(cmp_mine, cmp_theirs):
                differences[name] = {"self": mine, "other": theirs}

        return differences

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively, all other fields are overwritten.
        keys of `nested_dict` that do not match a field name are ignored.

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            key: str = field.name
            current = getattr(self, key)

            if key not in nested_dict:
                continue

            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(nested_dict[key])
            else:
                setattr(self, key, nested_dict[key])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_text = json.dumps(self.serialize())
        return self.__class__.load(json.loads(json_text))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_text = json.dumps(self.serialize())
        return self.__class__.load(json.loads(json_text))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        try:
443            if self == other:
444                return diff_result
445        except Exception:
446            pass
447
448        # if we are working with serialized data, serialize the instances
449        if of_serialized:
450            ser_self: dict = self.serialize()
451            ser_other: dict = other.serialize()
452
453        # for each field in the class
454        for field in dataclasses.fields(self):  # type: ignore[arg-type]
455            # skip fields that are not for comparison
456            if not field.compare:
457                continue
458
459            # get values
460            field_name: str = field.name
461            self_value = getattr(self, field_name)
462            other_value = getattr(other, field_name)
463
464            # if the values are both serializable dataclasses, recurse
465            if isinstance(self_value, SerializableDataclass) and isinstance(
466                other_value, SerializableDataclass
467            ):
468                nested_diff: dict = self_value.diff(
469                    other_value, of_serialized=of_serialized
470                )
471                if nested_diff:
472                    diff_result[field_name] = nested_diff
473            # only support serializable dataclasses
474            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
475                other_value
476            ):
477                raise ValueError("Non-serializable dataclass is not supported")
478            else:
479                # get the values of either the serialized or the actual values
480                self_value_s = ser_self[field_name] if of_serialized else self_value
481                other_value_s = ser_other[field_name] if of_serialized else other_value
482                # compare the values
483                if not array_safe_eq(self_value_s, other_value_s):
484                    diff_result[field_name] = {"self": self_value, "other": other_value}
485
486        # return the diff result
487        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
489    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
490        """update the instance from a nested dict, useful for configuration from command line args
491
492        # Parameters:
493            - `nested_dict : dict[str, Any]`
494                nested dict to update the instance with
495        """
496        for field in dataclasses.fields(self):  # type: ignore[arg-type]
497            field_name: str = field.name
498            self_value = getattr(self, field_name)
499
500            if field_name in nested_dict:
501                if isinstance(self_value, SerializableDataclass):
502                    self_value.update_from_nested_dict(nested_dict[field_name])
503                else:
504                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

  • nested_dict : dict[str, Any] nested dict to update the instance with
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]:
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    """memoized `typing.get_type_hints` lookup for a class

    results are stored by `functools.lru_cache`, which requires `cls` to be hashable
    """
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]:
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup (`get_cls_type_hints_cached`) first and falls back to a
    direct `typing.get_type_hints` call if the cached result is empty

    # Parameters:
     - `cls : Type[T]`
       class to get the type hints of

    # Returns:
     - `dict[str, Any]`
       non-empty mapping of names to type hints

    # Raises:
     - `TypeError` : if the type hints cannot be resolved (a `TypeError`, `NameError` or `ValueError` was raised while getting them) or if they are empty. the original exception is chained as the cause
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` raises `TypeError` when `cls` is not a dataclass;
        # without this guard that secondary error would replace the diagnostic below
        fields_info: str
        try:
            fields_info = f"{dataclasses.fields(cls) = }"  # type: ignore[arg-type]
        except TypeError as fields_err:
            fields_info = f"dataclasses.fields(cls) failed: {fields_err!r}"

        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {fields_info}\n"
            + f"  {e = }"
        ) from e

    return cls_type_hints

helper function to get type hints for a class

class KWOnlyError(builtins.NotImplementedError):
class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10

    raised by `serializable_dataclass` when `kw_only=True` is passed on such an interpreter
    """

kw-only dataclasses are not supported in python <3.10

Inherited Members
builtins.NotImplementedError
NotImplementedError
builtins.BaseException
with_traceback
add_note
args
class FieldError(builtins.ValueError):
class FieldError(ValueError):
    """base class for field errors"""

base class for field errors

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class NotSerializableFieldException(FieldError):
class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`

    raised while serializing when a dataclass field turns out not to be a `SerializableField`
    """

field is not a SerializableField

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldSerializationError(FieldError):
class FieldSerializationError(FieldError):
    """error while serializing a field

    wraps (via exception chaining) whatever was raised while serializing a single field
    """

error while serializing a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldLoadingError(FieldError):
class FieldLoadingError(FieldError):
    """error while loading a field

    raised when a field's type hint provides a `load` method but the value to load is not a dict
    """

error while loading a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldTypeMismatchError(FieldError, builtins.TypeError):
class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint

    both a `FieldError` (and thus a `ValueError`) and a `TypeError`, so it can be caught as either
    """

error when a field type does not match the type hint

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, methods_no_override: list[str] | None = None, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
    - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `eq : bool`
       *(passed to dataclasses.dataclass)*
       note that unless `"__eq__"` is listed in `methods_no_override`, `__eq__` is replaced by this decorator afterwards
       (defaults to `True`)
    - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
       the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Make sure every annotated name on the class is backed by a SerializableField
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # bind `value` before the `try`: if `getattr` itself fails, the error message
                    # below must still be buildable (and must not show a previous field's value)
                    value: Any = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        # NOTE(review): this is an `if`, not an `elif`, so after the branch above ran
                        # the `elif` below can still apply `serialization_fn` to the serialized data -- confirm this is intended
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # (it receives the whole `data` mapping, not just this field's value)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # the documented (and whitelisted) name is `validate_fields_types`; the singular
        # `validate_field_type` is the key this check historically used, so it is still
        # honored to stay backward-compatible
        if not (
            {"validate_fields_types", "validate_field_type"} & _methods_no_override
        ):
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
  • repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
  • order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
  • unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
  • frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
  • properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict SerializableDataclass only (defaults to None)
  • register_handler : bool if true, register the class with ZANJ for loading SerializableDataclass only (defaults to True)
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false SerializableDataclass only
  • on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True SerializableDataclass only
  • methods_no_override : list[str]|None list of methods that should not be overridden by the decorator by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence SerializableDataclass only (defaults to None)
  • **kwargs (passed to dataclasses.dataclass)

Returns:

  • _type_ the decorated class

Raises: