muutils.json_serialize
Submodule for serializing things to JSON in a recoverable way.

You can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is: if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
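For example, a quick sketch of the intended usage (the output shown in the comment assumes, per the handler description below, that values which are already JSON-valid pass through unchanged):

```python
from muutils.json_serialize import json_serialize

# any mix of basic containers goes in, JSON-compatible data comes out
data = {"name": "experiment-1", "values": [1, 2.5, None], "meta": {"ok": True}}
print(json_serialize(data))
# {'name': 'experiment-1', 'values': [1, 2.5, None], 'meta': {'ok': True}}
```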
It does so by looking in `DEFAULT_HANDLERS`: an object is kept as-is if it's already valid JSON, then we try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add your own handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`.
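A hedged sketch of adding a custom handler. From the `JsonSerializer` source further down we only know that each handler exposes `check`, `serialize_func`, and `uid`; the keyword arguments to the `SerializerHandler` constructor below are an assumption based on those attribute names:

```python
from muutils.json_serialize.json_serialize import JsonSerializer, SerializerHandler

# NOTE: keyword names mirror the attributes JsonSerializer reads off each
# handler (handler.check, handler.serialize_func, handler.uid); the actual
# constructor signature may differ
complex_handler = SerializerHandler(
    check=lambda self, obj, path: isinstance(obj, complex),
    serialize_func=lambda self, obj, path: {"real": obj.real, "imag": obj.imag},
    uid="complex_handler",
)
serializer = JsonSerializer(handlers_pre=(complex_handler,))
serializer.json_serialize(3 + 4j)  # expected: {'real': 3.0, 'imag': 4.0}
```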
Additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
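A minimal sketch of that pattern; `Config` and its fields are illustrative names:

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass
class Config(SerializableDataclass):
    name: str
    threshold: float = serializable_field(
        default=0.5,
        serialization_fn=lambda x: str(x),  # store the float as a string
        deserialize_fn=lambda x: float(x),  # parse it back on load
    )

cfg = Config(name="run-1")
print(cfg.serialize())
# {'__format__': 'Config(SerializableDataclass)', 'name': 'run-1', 'threshold': '0.5'}
assert Config.load(cfg.serialize()) == cfg
```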
This module plays nicely with, and is a dependency of, the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends it to support saving things to disk more efficiently than plain JSON (arrays are saved as npy files, for example) and to automatically detect how to load saved objects back into their original classes.
1"""submodule for serializing things to json in a recoverable way 2 3you can throw *any* object into `muutils.json_serialize.json_serialize` 4and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mappting to `JSONitem`. 5 6The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ). 7 8it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if its already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre` 9 10additionally, `SerializeableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializeableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields. 11 12This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes. 13 14""" 15 16from __future__ import annotations 17 18from muutils.json_serialize.array import arr_metadata, load_array 19from muutils.json_serialize.json_serialize import ( 20 BASE_HANDLERS, 21 JsonSerializer, 22 json_serialize, 23) 24from muutils.json_serialize.serializable_dataclass import ( 25 SerializableDataclass, 26 serializable_dataclass, 27 serializable_field, 28) 29from muutils.json_serialize.util import try_catch, JSONitem, dc_eq 30 31__all__ = [ 32 # submodules 33 "array", 34 "json_serialize", 35 "serializable_dataclass", 36 "serializable_field", 37 "util", 38 # imports 39 "arr_metadata", 40 "load_array", 41 "BASE_HANDLERS", 42 "JSONitem", 43 "JsonSerializer", 44 "json_serialize", 45 "try_catch", 46 "JSONitem", 47 "dc_eq", 48 "serializable_dataclass", 49 "serializable_field", 50 "SerializableDataclass", 51]
```python
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
serialize object to json-serializable object with default config
```python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        (defaults to `True`)
    - `repr : bool`
        (defaults to `True`)
    - `order : bool`
        (defaults to `False`)
    - `unsafe_hash : bool`
        (defaults to `False`)
    - `frozen : bool`
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
    - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
    - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
    - `_type_`
        the decorated class

    # Raises:
    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check: `kw_only` is not supported in python <3.10, and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
```
decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 `__annotations__` to determine fields.

If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```
```python
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```
Parameters:

- `_cls : _type_`
  class to decorate. don't pass this arg, just use this as a decorator
  (defaults to `None`)
- `init : bool`
  (defaults to `True`)
- `repr : bool`
  (defaults to `True`)
- `order : bool`
  (defaults to `False`)
- `unsafe_hash : bool`
  (defaults to `False`)
- `frozen : bool`
  (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
  **SerializableDataclass only:** which properties to add to the serialized data dict
  (defaults to `None`)
- `register_handler : bool`
  **SerializableDataclass only:** if true, register the class with ZANJ for loading
  (defaults to `True`)
- `on_typecheck_error : ErrorMode`
  **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
- `on_typecheck_mismatch : ErrorMode`
  **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

Returns:

- `_type_`
  the decorated class

Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
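To make the type validation concrete, a small sketch; `Point` is illustrative, and whether the mismatch raises or merely warns depends on the configured default for `on_typecheck_mismatch` (the `FieldTypeMismatchError` path comes from the decorator source above):

```python
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

@serializable_dataclass
class Point(SerializableDataclass):
    x: int
    y: int

Point.load({"x": 1, "y": 2})       # loads fine
Point.load({"x": "oops", "y": 2})  # type mismatch: raises FieldTypeMismatchError
                                   # (or warns, depending on the default ErrorMode)
```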
```python
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
```
Create a new `SerializableField`

```
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```

new Parameters:

- `serialize`: whether to serialize this field when serializing the class
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Gotchas:

`loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )
```

using `deserialize_fn` instead:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )
```

In the above code, `my_field` is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you.

TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
```python
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
```
get metadata for a numpy array
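A quick sketch of the output, assuming `array_n_elements` returns the total element count:

```python
import numpy as np
from muutils.json_serialize import arr_metadata

print(arr_metadata(np.zeros((2, 3), dtype=np.float32)))
# {'shape': [2, 3], 'dtype': 'float32', 'n_elements': 6}
```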
```python
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")
```
load a json-serialized array, infer the mode if not specified
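For instance, a hand-built payload in the `"array_list_meta"` format handled by the source above (note that `infer_array_mode` still runs as a consistency check, so the keys must match what it expects):

```python
import numpy as np
from muutils.json_serialize import load_array

payload = {"data": [[1, 2], [3, 4]], "dtype": "int64", "shape": [2, 2]}
arr = load_array(payload, array_mode="array_list_meta")
assert isinstance(arr, np.ndarray) and arr.shape == (2, 2)
```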
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
Json serialization class (holds configs)

Parameters:

- `array_mode : ArrayMode`
  how to write arrays
  (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
  what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
  (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
  handlers to use before the default handlers
  (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
  default handlers to use
  (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
  changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
  (defaults to `False`)

Raises:

- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
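A short usage sketch of the error modes; the lambda below stands in for any object without a handler, since exactly which objects `DEFAULT_HANDLERS` covers is not shown here:

```python
from muutils.json_serialize import JsonSerializer

# passing the string works because the source converts it via ErrorMode.from_any
serializer = JsonSerializer(error_mode="warn")
# on failure, "warn" mode emits a warning and falls back to repr(obj);
# "except" (the default) would raise SerializationException instead
result = serializer.json_serialize(lambda x: x)
```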
```python
    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )
```
```python
    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)
```
```python
    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
try to turn any object into something hashable
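A sketch of `hashify`; the exact tuple layout comes from `_recursive_hashify`, whose source is not shown here, so the commented output is an assumption:

```python
from muutils.json_serialize import JsonSerializer

serializer = JsonSerializer()
h = serializer.hashify({"a": [1, 2]})
# dicts and lists become (nested) tuples so the result is hashable, e.g.:
# (('a', (1, 2)),)
hash(h)  # does not raise
```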
```python
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
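Concretely, per the source above (`risky_divide` is an illustrative name):

```python
from muutils.json_serialize import try_catch

@try_catch
def risky_divide(a: float, b: float) -> float:
    return a / b

print(risky_divide(1, 2))  # 0.5
print(risky_divide(1, 0))  # 'ZeroDivisionError: division by zero'
```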
```python
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
    [START]
       ▼
    ┌───────────┐  ┌─────────┐
    │dc1 is dc2?├─►│ classes │
    └──┬────────┘No│ match?  │
       ────        ├─────────┤
      (True)◄──┘Yes │No       │Yes
       ────         ▼          ▼
          ┌────────────────┐ ┌────────────┐
          │ except when    │ │ fields keys│
          │ class mismatch?│ │ match?     │
          ├───────────┬────┘ ├───────┬────┘
          │Yes        │No    │No     │Yes
          ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    { raise     } │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────    ─────   │  (True)
    { raise         }  (False)◄─┘   ────
    { AttributeError}   ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
```
checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
  if `True`, will throw `TypeError` if the classes are different.
  if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
  (default: `False`)
- `false_when_class_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False`.
  if `True`, will return `False` if the classes are different.
  if `False`, will attempt to compare the fields.
  (default: `True`)
- `except_when_field_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
  if `True`, will throw `AttributeError` if the fields are different.
  (default: `False`)

Returns:

- `bool`: True if the dataclasses are equal, False otherwise

Raises:

- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields
TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
[START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
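A short sketch of why this exists: with numpy-array fields, the generated dataclass `__eq__` would try to coerce an elementwise `array == array` result to a single bool and raise, while `dc_eq` compares fields with an array-safe check (`Sample` is an illustrative class):

```python
import dataclasses
import numpy as np
from muutils.json_serialize import dc_eq

@dataclasses.dataclass
class Sample:
    name: str
    values: np.ndarray

a = Sample("x", np.array([1, 2, 3]))
b = Sample("x", np.array([1, 2, 3]))

# `a == b` would raise: "The truth value of an array with more than one element is ambiguous"
print(dc_eq(a, b))  # True
```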
```python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
```
Base class for serializable dataclasses

only for linting and type checking, still need to call the `serializable_dataclass` decorator

Usage:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```python
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
```python
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
```
given a dataclass, check the field matches the type hint
```python
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result
```
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
Parameters:

- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values
  (defaults to `False`)

Returns:

- `dict[str, Any]`

Raises:

- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
```python
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
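A small usage sketch; `TrainConfig` is an illustrative class:

```python
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

@serializable_dataclass
class TrainConfig(SerializableDataclass):
    lr: float = 1e-3
    epochs: int = 10

cfg = TrainConfig()
cfg.update_from_nested_dict({"lr": 3e-4})  # e.g. from parsed CLI args
print(cfg.lr, cfg.epochs)  # 0.0003 10
```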