docs for muutils v0.7.0
View Source on GitHub

muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping strings to JSONitems.

The goal of this is that if you just want to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

it will do so by looking in DEFAULT_HANDLERS, which will keep the object as-is if it's already valid JSON, then try to find a .serialize() method on it, and then fall back to a number of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre.

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and using serializable_field in place of dataclasses.field when defining non-standard fields.

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
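The handler fallback described above can be sketched in plain Python. This is a conceptual sketch only, not the actual muutils implementation; the `simple_serialize` function and `Point` class below are hypothetical stand-ins that mimic the behavior of `json_serialize`:

```python
from typing import Any

def simple_serialize(obj: Any) -> Any:
    """conceptual sketch of the handler chain: keep valid JSON values as-is,
    prefer a .serialize() method, then fall back to special cases"""
    # base types are already valid JSON
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    # dicts map (stringified) keys to serialized values
    if isinstance(obj, dict):
        return {str(k): simple_serialize(v) for k, v in obj.items()}
    # lists and tuples both become lists
    if isinstance(obj, (list, tuple)):
        return [simple_serialize(x) for x in obj]
    # objects that know how to serialize themselves
    if hasattr(obj, "serialize") and callable(obj.serialize):
        return simple_serialize(obj.serialize())
    # last resort: a string representation
    return str(obj)

class Point:
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y
    def serialize(self) -> dict:
        return {"x": self.x, "y": self.y}

print(simple_serialize({"p": Point(1, 2), "t": (1, 2)}))
# -> {'p': {'x': 1, 'y': 2}, 't': [1, 2]}
```

The real implementation generalizes this: each branch above corresponds to a handler object, and `handlers_pre` lets you prepend your own before the defaults run.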


"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping strings to `JSONitem`s.

The goal of this is that if you just want to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep the object as-is if it's already valid JSON, then try to find a `.serialize()` method on it, and then fall back to a number of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.

"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "JSONitem",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]

def json_serialize(obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)

serialize object to json-serializable object with default config

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass(_cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check: kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isn't a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so that's why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool (defaults to True)
  • repr : bool (defaults to True)
  • order : bool (defaults to False)
  • unsafe_hash : bool (defaults to False)
  • frozen : bool (defaults to False)
  • properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
  • register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
  • on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
  • on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True

Returns:

  • _type_ the decorated class

Raises:

  • KWOnlyError : only raised if kw_only is True and python version is <3.10, since dataclasses.dataclass does not support this
  • NotSerializableFieldException : if a field is not a SerializableField
  • FieldSerializationError : if there is an error serializing a field
  • AttributeError : if a property is not found on the class
  • FieldLoadingError : if there is an error loading a field
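To make the division of labor concrete, here is roughly what the decorator generates, written by hand with stdlib dataclasses. This is a simplified, hypothetical sketch: the real decorator additionally performs type validation, serializes properties, supports per-field `serialization_fn`/`deserialize_fn`, and registers the class with ZANJ.

```python
import dataclasses
from typing import Any

@dataclasses.dataclass
class MyClass:
    a: int
    b: str

    def serialize(self) -> dict[str, Any]:
        # the generated serialize() emits a __format__ key plus one entry per field
        result: dict[str, Any] = {
            "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
        }
        for field in dataclasses.fields(self):
            result[field.name] = getattr(self, field.name)
        return result

    @classmethod
    def load(cls, data: dict[str, Any]) -> "MyClass":
        # the generated load() keeps only keys that correspond to init-able fields,
        # so extra keys like __format__ are ignored
        field_names = {f.name for f in dataclasses.fields(cls) if f.init}
        return cls(**{k: v for k, v in data.items() if k in field_names})

print(MyClass(a=1, b="q").serialize())
# -> {'__format__': 'MyClass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```

The round trip `MyClass.load(instance.serialize()) == instance` holds because `dataclasses.dataclass` supplies `__eq__` by default.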
def serializable_field(*_args, default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[mappingproxy] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any) -> Any:
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field value, so it must index into the dict itself. To undo the string serialization below, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )

Create a new SerializableField

default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,

new Parameters:

  • serialize: whether to serialize this field when serializing the class
  • serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
  • loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
  • deserialize_fn: new alternative to loading_fn. takes only the field's value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
  • assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
  • custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Gotchas:

  • loading_fn takes the dict of the class, not the field value, so it must index into the dict itself. To undo the string serialization below, you'd write:
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

TODO: custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
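At its core, the serialization_fn / deserialize_fn pair is just an encoder and decoder for one field's value. A stdlib-only sketch, assuming a hypothetical set-valued field that must become a JSON list (the two lambdas below are exactly what you would pass to `serializable_field`):

```python
import json

# a set is not JSON-serializable, so the field needs an encoder/decoder pair,
# as you would pass via serializable_field(serialization_fn=..., deserialize_fn=...)
serialization_fn = lambda s: sorted(s)  # set -> sorted list (JSON-friendly)
deserialize_fn = lambda lst: set(lst)   # list -> set

tags = {"b", "a", "c"}
encoded = json.dumps(serialization_fn(tags))
decoded = deserialize_fn(json.loads(encoded))
print(encoded, decoded == tags)
# -> ["a", "b", "c"] True
```

Note the asymmetry with `loading_fn`: `deserialize_fn` receives only the field's already-extracted value, which is why it composes cleanly like this.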

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }

get metadata for a numpy array

def load_array(arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Any:
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if it's already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")

load a json-serialized array, infer the mode if not specified
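To illustrate the `"array_list_meta"` format this function reads, here is a hand-built payload (real payloads are produced by the serializer; this one is illustrative), loaded with the same steps as the corresponding branch in the source above:

```python
import numpy as np

# A hand-built payload in the "array_list_meta" format accepted by load_array.
payload = {
    "__format__": "array_list_meta",
    "shape": [2, 3],
    "dtype": "int64",
    "data": [[1, 2, 3], [4, 5, 6]],
}

# mirrors the "array_list_meta" branch of load_array above
data = np.array(payload["data"], dtype=payload["dtype"])
assert tuple(payload["shape"]) == tuple(data.shape)  # shape check from the source
```

With muutils installed, calling `load_array(payload)` should perform the same steps, inferring the mode from the payload.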

BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]
class JsonSerializer:
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)

Json serialization class (holds configs)

Parameters:

  • array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
  • error_mode : ErrorMode what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
  • handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
  • handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
  • write_only_format : bool changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

  • ValueError: on init, if args is not empty
  • SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
JsonSerializer( *args, array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta', error_mode: muutils.errormode.ErrorMode = ErrorMode.Except, handlers_pre: None = (), handlers_default: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')), write_only_format: bool = False)
    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
write_only_format: bool
handlers: None
def json_serialize( self, obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)
def hashify( self, obj: Any, path: tuple[typing.Union[str, int], ...] = (), force: bool = True) -> Union[bool, int, float, str, tuple]:
    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)

try to turn any object into something hashable
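The handler resolution in `json_serialize` can be sketched with plain callables: handlers from `handlers_pre` are consulted before the defaults, and the first handler whose check passes wins. The names and `(check, serialize)` tuple shape below are illustrative, not the real `SerializerHandler` API:

```python
# Minimal sketch of JsonSerializer's handler dispatch: handlers are
# (check, serialize) pairs tried in order, handlers_pre before defaults.
def make_dispatch(handlers_pre, handlers_default):
    def dispatch(obj):
        for check, serialize in tuple(handlers_pre) + tuple(handlers_default):
            if check(obj):
                return serialize(obj)
        raise ValueError(f"no handler found for object with {type(obj) = }")
    return dispatch

dispatch = make_dispatch(
    handlers_pre=[
        # custom handler, takes precedence over the defaults
        (lambda o: isinstance(o, complex), lambda o: {"re": o.real, "im": o.imag}),
    ],
    handlers_default=[
        (lambda o: isinstance(o, (bool, int, float, str, type(None))), lambda o: o),
        (lambda o: isinstance(o, (list, tuple)), lambda o: [dispatch(x) for x in o]),
    ],
)

dispatch([1, 1 + 2j])  # [1, {"re": 1.0, "im": 2.0}]
```

This is the same precedence rule `JsonSerializer` uses when it concatenates `handlers_pre` with `handlers_default` in `__init__`.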

def try_catch(func: Callable):
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception
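For example, applying the decorator (copied from the source above) to a function that can raise:

```python
import functools

# try_catch as defined in the source above: exceptions become
# "{ExceptionClass}: {message}" strings instead of propagating
def try_catch(func):
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"
    return newfunc

@try_catch
def invert(x: float) -> float:
    return 1 / x

invert(2)  # 0.5
invert(0)  # "ZeroDivisionError: division by zero"
```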

def dc_eq( dc1, dc2, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
              [START]
                 ▼
           ┌───────────┐  ┌─────────┐
           │dc1 is dc2?├─►│ classes │
           └──┬────────┘No│ match?  │
      ────    │           ├─────────┤
     (True)◄──┘Yes        │No       │Yes
      ────                ▼         ▼
          ┌────────────────┐ ┌────────────┐
          │ except when    │ │ fields keys│
          │ class mismatch?│ │ match?     │
          ├───────────┬────┘ ├───────┬────┘
          │Yes        │No    │No     │Yes
          ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    { raise     } │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────     ─────  │  (True)
    { raise         }   (False)◄┘   ────
    { AttributeError}    ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw AttributeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields

TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

          [START]
             ▼
       ┌───────────┐  ┌─────────┐
       │dc1 is dc2?├─►│ classes │
       └──┬────────┘No│ match?  │
  ────    │           ├─────────┤
 (True)◄──┘Yes        │No       │Yes
  ────                ▼         ▼
      ┌────────────────┐ ┌────────────┐
      │ except when    │ │ fields keys│
      │ class mismatch?│ │ match?     │
      ├───────────┬────┘ ├───────┬────┘
      │Yes        │No    │No     │Yes
      ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
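The happy path of this flowchart can be condensed into a small stand-in, with `array_safe_eq` replaced by plain `==` (an assumption that holds when no numpy arrays are involved):

```python
import dataclasses

# Condensed stand-in for dc_eq's main path; array_safe_eq is replaced
# by plain `==` (assumption: fields hold no numpy arrays)
def dc_eq_sketch(dc1, dc2, except_when_class_mismatch: bool = False) -> bool:
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            raise TypeError("classes do not match")
        return False
    # compare only fields with compare=True, as dc_eq does
    return all(
        getattr(dc1, fld.name) == getattr(dc2, fld.name)
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

@dataclasses.dataclass
class Point:
    x: int
    y: int
    label: str = dataclasses.field(default="", compare=False)

dc_eq_sketch(Point(1, 2), Point(1, 2, label="other"))  # True: label not compared
dc_eq_sketch(Point(1, 2), Point(1, 3))  # False
```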
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
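The class also provides an `update_from_nested_dict` helper (shown in the source above) for applying nested config overrides, e.g. from command line args. The stand-in below mirrors its logic using plain dataclasses in place of `SerializableDataclass`; the `Config`/`Optimizer` classes are illustrative:

```python
import dataclasses

# Stand-in mirroring update_from_nested_dict above, using plain
# dataclasses in place of SerializableDataclass
def update_from_nested_dict(obj, nested_dict: dict) -> None:
    for field in dataclasses.fields(obj):
        if field.name in nested_dict:
            self_value = getattr(obj, field.name)
            if dataclasses.is_dataclass(self_value):
                # recurse into nested dataclasses
                update_from_nested_dict(self_value, nested_dict[field.name])
            else:
                setattr(obj, field.name, nested_dict[field.name])

@dataclasses.dataclass
class Optimizer:
    lr: float = 0.1

@dataclasses.dataclass
class Config:
    name: str = "run"
    optimizer: Optimizer = dataclasses.field(default_factory=Optimizer)

cfg = Config()
update_from_nested_dict(cfg, {"optimizer": {"lr": 0.01}})
# cfg.optimizer.lr is now 0.01; cfg.name is untouched
```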
def serialize(self) -> dict[str, typing.Any]:
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

```python
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```

Parameters:

  • `other : SerializableDataclass`
    other instance to compare against
  • `of_serialized : bool`
    if `True`, compare serialized values rather than raw values (defaults to `False`)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
487    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
488        """update the instance from a nested dict, useful for configuration from command line args
489
490        # Parameters:
491            - `nested_dict : dict[str, Any]`
492                nested dict to update the instance with
493        """
494        for field in dataclasses.fields(self):  # type: ignore[arg-type]
495            field_name: str = field.name
496            self_value = getattr(self, field_name)
497
498            if field_name in nested_dict:
499                if isinstance(self_value, SerializableDataclass):
500                    self_value.update_from_nested_dict(nested_dict[field_name])
501                else:
502                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with