muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
Calling d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.
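To make the failure concrete, here is a minimal sketch using only the standard library. The `Config` class and its `out_dir` field are made up for illustration and are not part of this module; `pathlib.Path` stands in for any field type that `json` does not know how to encode.

```python
import dataclasses
import json
from pathlib import Path


@dataclasses.dataclass
class Config:
    # hypothetical example class, not part of muutils
    name: str
    out_dir: Path  # pathlib.Path is not json-serializable


cfg = Config(name="run-1", out_dir=Path("results"))

# asdict itself succeeds and keeps the Path object as-is
d = dataclasses.asdict(cfg)

# json.dumps is where things break
try:
    json.dumps(d)
    dumps_failed = False
except TypeError as err:
    dumps_failed = True
    print(f"json.dumps failed: {err}")
```

The dict is built without complaint; the error only surfaces at encoding time, which is the gap a `serialize()` method is meant to close.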
Instead, you define your class:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
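The round trip above can be imitated with plain `dataclasses` to see what the decorator is doing conceptually. This is a rough sketch, not the library's implementation: `Point` is a made-up class, and `FORMAT_KEY = "__format__"` is a placeholder, since the actual value of `_FORMAT_KEY` is not shown here.

```python
import dataclasses
import json
from typing import Any, Dict

# placeholder key name for this sketch only; the real `_FORMAT_KEY` value
# lives in muutils.json_serialize.util and is not shown in these docs
FORMAT_KEY = "__format__"


@dataclasses.dataclass
class Point:
    # hypothetical class used only for illustration
    a: int
    b: str

    def serialize(self) -> Dict[str, Any]:
        # tag the output with the class name so a reader can tell what it was
        out: Dict[str, Any] = {
            FORMAT_KEY: f"{type(self).__name__}(SerializableDataclass)"
        }
        for f in dataclasses.fields(self):
            out[f.name] = getattr(self, f.name)
        return out

    @classmethod
    def load(cls, data: Dict[str, Any]) -> "Point":
        # ignore keys that are not fields (such as the format tag)
        names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in names})


p = Point(a=1, b="q")
s = json.dumps(p.serialize())
restored = Point.load(json.loads(s))
```

Here `restored == p` holds because dataclasses compare field by field, and the format tag is dropped on load rather than passed to the constructor.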
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        # store the class name ("ReLU"); str(x) would give "ReLU()", which getattr cannot look up
        serialization_fn=lambda x: x.__class__.__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
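The per-field hook idea can be tried without torch. The sketch below is an assumption-laden stand-in, not how `serializable_field` is implemented: `hooked_field`, `serialize`, `load`, and `Step` are hypothetical names, and the converters are simply stored in the standard `metadata` mapping of `dataclasses.field`. `fractions.Fraction` plays the role of the non-json-serializable value.

```python
import dataclasses
import json
from fractions import Fraction
from typing import Any, Callable, Dict


def hooked_field(*, default: Any, to_json: Callable, from_json: Callable):
    """hypothetical helper: stash per-field converters in `metadata`"""
    return dataclasses.field(
        default=default,
        metadata={"to_json": to_json, "from_json": from_json},
    )


@dataclasses.dataclass
class Step:
    # hypothetical class used only for illustration
    name: str
    ratio: Fraction = hooked_field(
        default=Fraction(1, 2),
        to_json=str,  # Fraction(3, 4) becomes the string "3/4"
        from_json=Fraction,  # "3/4" is parsed back into Fraction(3, 4)
    )


def serialize(obj: Any) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        to_json = f.metadata.get("to_json")
        # apply the field's converter if it has one, else keep the raw value
        out[f.name] = to_json(value) if to_json else value
    return out


def load(cls: type, data: Dict[str, Any]) -> Any:
    kwargs: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name in data:
            from_json = f.metadata.get("from_json")
            raw = data[f.name]
            kwargs[f.name] = from_json(raw) if from_json else raw
    return cls(**kwargs)


step = Step(name="warmup", ratio=Fraction(3, 4))
s = json.dumps(serialize(step))
restored = load(Step, json.loads(s))
```

The important property is that the two converters are inverses of each other on the values you actually store; if `to_json` loses information, `load` cannot get it back.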
1"""save and load objects to and from json or compatible formats in a recoverable way 2 3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 5 6Instead, you define your class: 7 8```python 9@serializable_dataclass 10class MyClass(SerializableDataclass): 11 a: int 12 b: str 13``` 14 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 16 17 >>> my_obj = MyClass(a=1, b="q") 18 >>> s = json.dumps(my_obj.serialize()) 19 >>> s 20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 21 >>> read_obj = MyClass.load(json.loads(s)) 22 >>> read_obj == my_obj 23 True 24 25This isn't too impressive on its own, but it gets more useful when you have nested classses, 26or fields that are not json-serializable by default: 27 28```python 29@serializable_dataclass 30class NestedClass(SerializableDataclass): 31 x: str 32 y: MyClass 33 act_fun: torch.nn.Module = serializable_field( 34 default=torch.nn.ReLU(), 35 serialization_fn=lambda x: str(x), 36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 37 ) 38``` 39 40which gives us: 41 42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 43 >>> s = json.dumps(nc.serialize()) 44 >>> s 45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 46 >>> read_nc = NestedClass.load(json.loads(s)) 47 >>> read_nc == nc 48 True 49 50""" 51 52from __future__ import annotations 53 54import abc 55import dataclasses 56import functools 57import json 58import sys 59import typing 60import warnings 61from typing import Any, Optional, Type, TypeVar 62 63from muutils.errormode import ErrorMode 64from muutils.validate_type import validate_type 65from muutils.json_serialize.serializable_field import ( 66 SerializableField, 
67 serializable_field, 68) 69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq 70 71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 72 73 74def _dataclass_transform_mock( 75 *, 76 eq_default: bool = True, 77 order_default: bool = False, 78 kw_only_default: bool = False, 79 frozen_default: bool = False, 80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (), 81 **kwargs: Any, 82) -> typing.Callable: 83 "mock `typing.dataclass_transform` for python <3.11" 84 85 def decorator(cls_or_fn): 86 cls_or_fn.__dataclass_transform__ = { 87 "eq_default": eq_default, 88 "order_default": order_default, 89 "kw_only_default": kw_only_default, 90 "frozen_default": frozen_default, 91 "field_specifiers": field_specifiers, 92 "kwargs": kwargs, 93 } 94 return cls_or_fn 95 96 return decorator 97 98 99if sys.version_info < (3, 11): 100 dataclass_transform = _dataclass_transform_mock 101else: 102 dataclass_transform = typing.dataclass_transform 103 104 105T = TypeVar("T") 106 107 108class CantGetTypeHintsWarning(UserWarning): 109 "special warning for when we can't get type hints" 110 111 pass 112 113 114class ZanjMissingWarning(UserWarning): 115 "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work" 116 117 pass 118 119 120_zanj_loading_needs_import: bool = True 121"flag to keep track of if we have successfully imported ZANJ" 122 123 124def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): 125 """Register a serializable dataclass with the ZANJ import 126 127 this allows `ZANJ().read()` to load the class and not just return plain dicts 128 129 130 # TODO: there is some duplication here with register_loader_handler 131 """ 132 global _zanj_loading_needs_import 133 134 if _zanj_loading_needs_import: 135 try: 136 from zanj.loading import ( # type: ignore[import] 137 LoaderHandler, 138 register_loader_handler, 139 ) 
140 except ImportError: 141 # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter 142 # warnings.warn( 143 # "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`", 144 # ZanjMissingWarning, 145 # ) 146 return 147 148 _format: str = f"{cls.__name__}(SerializableDataclass)" 149 lh: LoaderHandler = LoaderHandler( 150 check=lambda json_item, path=None, z=None: ( # type: ignore 151 isinstance(json_item, dict) 152 and _FORMAT_KEY in json_item 153 and json_item[_FORMAT_KEY].startswith(_format) 154 ), 155 load=lambda json_item, path=None, z=None: cls.load(json_item), # type: ignore 156 uid=_format, 157 source_pckg=cls.__module__, 158 desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass", 159 ) 160 161 register_loader_handler(lh) 162 163 return lh 164 165 166_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN 167_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT 168 169 170class FieldIsNotInitOrSerializeWarning(UserWarning): 171 pass 172 173 174def SerializableDataclass__validate_field_type( 175 self: SerializableDataclass, 176 field: SerializableField | str, 177 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 178) -> bool: 179 """given a dataclass, check the field matches the type hint 180 181 this function is written to `SerializableDataclass.validate_field_type` 182 183 # Parameters: 184 - `self : SerializableDataclass` 185 `SerializableDataclass` instance 186 - `field : SerializableField | str` 187 field to validate, will get from `self.__dataclass_fields__` if an `str` 188 - `on_typecheck_error : ErrorMode` 189 what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` 190 (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`) 191 192 # Returns: 193 - `bool` 194 if the field type is correct. 
`False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore` 195 """ 196 on_typecheck_error = ErrorMode.from_any(on_typecheck_error) 197 198 # get field 199 _field: SerializableField 200 if isinstance(field, str): 201 _field = self.__dataclass_fields__[field] # type: ignore[attr-defined] 202 else: 203 _field = field 204 205 # do nothing case 206 if not _field.assert_type: 207 return True 208 209 # if field is not `init` or not `serialize`, skip but warn 210 # TODO: how to handle fields which are not `init` or `serialize`? 211 if not _field.init or not _field.serialize: 212 warnings.warn( 213 f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked", 214 FieldIsNotInitOrSerializeWarning, 215 ) 216 return True 217 218 assert isinstance( 219 _field, SerializableField 220 ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }" 221 222 # get field type hints 223 try: 224 field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name] 225 except KeyError as e: 226 on_typecheck_error.process( 227 ( 228 f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n" 229 + f"{get_cls_type_hints(self.__class__) = }\n" 230 + f"Python version is {sys.version_info = }. You can:\n" 231 + f" - disable `assert_type`. Currently: {_field.assert_type = }\n" 232 + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). 
You had {_field.type = }\n" 233 + " - use python 3.9.x or higher\n" 234 + " - specify custom type validation function via `custom_typecheck_fn`\n" 235 ), 236 except_cls=TypeError, 237 except_from=e, 238 ) 239 return False 240 241 # get the value 242 value: Any = getattr(self, _field.name) 243 244 # validate the type 245 try: 246 type_is_valid: bool 247 # validate the type with the default type validator 248 if _field.custom_typecheck_fn is None: 249 type_is_valid = validate_type(value, field_type_hint) 250 # validate the type with a custom type validator 251 else: 252 type_is_valid = _field.custom_typecheck_fn(field_type_hint) 253 254 return type_is_valid 255 256 except Exception as e: 257 on_typecheck_error.process( 258 "exception while validating type: " 259 + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }", 260 except_cls=ValueError, 261 except_from=e, 262 ) 263 return False 264 265 266def SerializableDataclass__validate_fields_types__dict( 267 self: SerializableDataclass, 268 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 269) -> dict[str, bool]: 270 """validate the types of all the fields on a `SerializableDataclass`. 
calls `SerializableDataclass__validate_field_type` for each field 271 272 returns a dict of field names to bools, where the bool is if the field type is valid 273 """ 274 on_typecheck_error = ErrorMode.from_any(on_typecheck_error) 275 276 # if except, bundle the exceptions 277 results: dict[str, bool] = dict() 278 exceptions: dict[str, Exception] = dict() 279 280 # for each field in the class 281 cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) # type: ignore[arg-type, assignment] 282 for field in cls_fields: 283 try: 284 results[field.name] = self.validate_field_type(field, on_typecheck_error) 285 except Exception as e: 286 results[field.name] = False 287 exceptions[field.name] = e 288 289 # figure out what to do with the exceptions 290 if len(exceptions) > 0: 291 on_typecheck_error.process( 292 f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" 293 + "\n\t" 294 + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), 295 except_cls=ValueError, 296 # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict 297 except_from=list(exceptions.values())[0], 298 ) 299 300 return results 301 302 303def SerializableDataclass__validate_fields_types( 304 self: SerializableDataclass, 305 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 306) -> bool: 307 """validate the types of all the fields on a `SerializableDataclass`. 
calls `SerializableDataclass__validate_field_type` for each field""" 308 return all( 309 SerializableDataclass__validate_fields_types__dict( 310 self, on_typecheck_error=on_typecheck_error 311 ).values() 312 ) 313 314 315@dataclass_transform( 316 field_specifiers=(serializable_field, SerializableField), 317) 318class SerializableDataclass(abc.ABC): 319 """Base class for serializable dataclasses 320 321 only for linting and type checking, still need to call `serializable_dataclass` decorator 322 323 # Usage: 324 325 ```python 326 @serializable_dataclass 327 class MyClass(SerializableDataclass): 328 a: int 329 b: str 330 ``` 331 332 and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 333 334 >>> my_obj = MyClass(a=1, b="q") 335 >>> s = json.dumps(my_obj.serialize()) 336 >>> s 337 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 338 >>> read_obj = MyClass.load(json.loads(s)) 339 >>> read_obj == my_obj 340 True 341 342 This isn't too impressive on its own, but it gets more useful when you have nested classses, 343 or fields that are not json-serializable by default: 344 345 ```python 346 @serializable_dataclass 347 class NestedClass(SerializableDataclass): 348 x: str 349 y: MyClass 350 act_fun: torch.nn.Module = serializable_field( 351 default=torch.nn.ReLU(), 352 serialization_fn=lambda x: str(x), 353 deserialize_fn=lambda x: getattr(torch.nn, x)(), 354 ) 355 ``` 356 357 which gives us: 358 359 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 360 >>> s = json.dumps(nc.serialize()) 361 >>> s 362 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 363 >>> read_nc = NestedClass.load(json.loads(s)) 364 >>> read_nc == nc 365 True 366 """ 367 368 def serialize(self) -> dict[str, Any]: 369 "returns the class as a dict, implemented by using `@serializable_dataclass` 
decorator" 370 raise NotImplementedError( 371 f"decorate {self.__class__ = } with `@serializable_dataclass`" 372 ) 373 374 @classmethod 375 def load(cls: Type[T], data: dict[str, Any] | T) -> T: 376 "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator" 377 raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") 378 379 def validate_fields_types( 380 self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR 381 ) -> bool: 382 """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" 383 return SerializableDataclass__validate_fields_types( 384 self, on_typecheck_error=on_typecheck_error 385 ) 386 387 def validate_field_type( 388 self, 389 field: "SerializableField|str", 390 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 391 ) -> bool: 392 """given a dataclass, check the field matches the type hint""" 393 return SerializableDataclass__validate_field_type( 394 self, field, on_typecheck_error=on_typecheck_error 395 ) 396 397 def __eq__(self, other: Any) -> bool: 398 return dc_eq(self, other) 399 400 def __hash__(self) -> int: 401 "hashes the json-serialized representation of the class" 402 return hash(json.dumps(self.serialize())) 403 404 def diff( 405 self, other: "SerializableDataclass", of_serialized: bool = False 406 ) -> dict[str, Any]: 407 """get a rich and recursive diff between two instances of a serializable dataclass 408 409 ```python 410 >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3)) 411 {'b': {'self': 2, 'other': 3}} 412 >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3))) 413 {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}} 414 ``` 415 416 # Parameters: 417 - `other : SerializableDataclass` 418 other instance to compare against 419 - `of_serialized : bool` 420 if true, compare serialized 
data and not raw values 421 (defaults to `False`) 422 423 # Returns: 424 - `dict[str, Any]` 425 426 427 # Raises: 428 - `ValueError` : if the instances are not of the same type 429 - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass` 430 """ 431 # match types 432 if type(self) is not type(other): 433 raise ValueError( 434 f"Instances must be of the same type, but got {type(self) = } and {type(other) = }" 435 ) 436 437 # initialize the diff result 438 diff_result: dict = {} 439 440 # if they are the same, return the empty diff 441 try: 442 if self == other: 443 return diff_result 444 except Exception: 445 pass 446 447 # if we are working with serialized data, serialize the instances 448 if of_serialized: 449 ser_self: dict = self.serialize() 450 ser_other: dict = other.serialize() 451 452 # for each field in the class 453 for field in dataclasses.fields(self): # type: ignore[arg-type] 454 # skip fields that are not for comparison 455 if not field.compare: 456 continue 457 458 # get values 459 field_name: str = field.name 460 self_value = getattr(self, field_name) 461 other_value = getattr(other, field_name) 462 463 # if the values are both serializable dataclasses, recurse 464 if isinstance(self_value, SerializableDataclass) and isinstance( 465 other_value, SerializableDataclass 466 ): 467 nested_diff: dict = self_value.diff( 468 other_value, of_serialized=of_serialized 469 ) 470 if nested_diff: 471 diff_result[field_name] = nested_diff 472 # only support serializable dataclasses 473 elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass( 474 other_value 475 ): 476 raise ValueError("Non-serializable dataclass is not supported") 477 else: 478 # get the values of either the serialized or the actual values 479 self_value_s = ser_self[field_name] if of_serialized else self_value 480 other_value_s = ser_other[field_name] if of_serialized else other_value 481 # compare the values 482 if not 
array_safe_eq(self_value_s, other_value_s): 483 diff_result[field_name] = {"self": self_value, "other": other_value} 484 485 # return the diff result 486 return diff_result 487 488 def update_from_nested_dict(self, nested_dict: dict[str, Any]): 489 """update the instance from a nested dict, useful for configuration from command line args 490 491 # Parameters: 492 - `nested_dict : dict[str, Any]` 493 nested dict to update the instance with 494 """ 495 for field in dataclasses.fields(self): # type: ignore[arg-type] 496 field_name: str = field.name 497 self_value = getattr(self, field_name) 498 499 if field_name in nested_dict: 500 if isinstance(self_value, SerializableDataclass): 501 self_value.update_from_nested_dict(nested_dict[field_name]) 502 else: 503 setattr(self, field_name, nested_dict[field_name]) 504 505 def __copy__(self) -> "SerializableDataclass": 506 "deep copy by serializing and loading the instance to json" 507 return self.__class__.load(json.loads(json.dumps(self.serialize()))) 508 509 def __deepcopy__(self, memo: dict) -> "SerializableDataclass": 510 "deep copy by serializing and loading the instance to json" 511 return self.__class__.load(json.loads(json.dumps(self.serialize()))) 512 513 514# cache this so we don't have to keep getting it 515# TODO: are the types hashable? does this even make sense? 
516@functools.lru_cache(typed=True) 517def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]: 518 "cached typing.get_type_hints for a class" 519 return typing.get_type_hints(cls) 520 521 522def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: 523 "helper function to get type hints for a class" 524 cls_type_hints: dict[str, Any] 525 try: 526 cls_type_hints = get_cls_type_hints_cached(cls) # type: ignore 527 if len(cls_type_hints) == 0: 528 cls_type_hints = typing.get_type_hints(cls) 529 530 if len(cls_type_hints) == 0: 531 raise ValueError(f"empty type hints for {cls.__name__ = }") 532 except (TypeError, NameError, ValueError) as e: 533 raise TypeError( 534 f"Cannot get type hints for {cls = }\n" 535 + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n" 536 + f" {dataclasses.fields(cls) = }\n" # type: ignore[arg-type] 537 + f" {e = }" 538 ) from e 539 540 return cls_type_hints 541 542 543class KWOnlyError(NotImplementedError): 544 "kw-only dataclasses are not supported in python <3.9" 545 546 pass 547 548 549class FieldError(ValueError): 550 "base class for field errors" 551 552 pass 553 554 555class NotSerializableFieldException(FieldError): 556 "field is not a `SerializableField`" 557 558 pass 559 560 561class FieldSerializationError(FieldError): 562 "error while serializing a field" 563 564 pass 565 566 567class FieldLoadingError(FieldError): 568 "error while loading a field" 569 570 pass 571 572 573class FieldTypeMismatchError(FieldError, TypeError): 574 "error when a field type does not match the type hint" 575 576 pass 577 578 579@dataclass_transform( 580 field_specifiers=(serializable_field, SerializableField), 581) 582def serializable_dataclass( 583 # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it 584 _cls=None, # type: ignore 585 *, 586 init: bool = True, 587 repr: bool = True, # this overrides the actual `repr` builtin, but we have to match the 
interface of `dataclasses.dataclass` 588 eq: bool = True, 589 order: bool = False, 590 unsafe_hash: bool = False, 591 frozen: bool = False, 592 properties_to_serialize: Optional[list[str]] = None, 593 register_handler: bool = True, 594 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 595 on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH, 596 methods_no_override: list[str] | None = None, 597 **kwargs, 598): 599 """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!** 600 601 types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE` 602 603 behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass` 604 605 Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. 606 607 Examines PEP 526 `__annotations__` to determine fields. 608 609 If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation. 610 611 ```python 612 @serializable_dataclass(kw_only=True) 613 class Myclass(SerializableDataclass): 614 a: int 615 b: str 616 ``` 617 ```python 618 >>> Myclass(a=1, b="q").serialize() 619 {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'} 620 ``` 621 622 # Parameters: 623 624 - `_cls : _type_` 625 class to decorate. 
don't pass this arg, just use this as a decorator 626 (defaults to `None`) 627 - `init : bool` 628 whether to add an `__init__` method 629 *(passed to dataclasses.dataclass)* 630 (defaults to `True`) 631 - `repr : bool` 632 whether to add a `__repr__` method 633 *(passed to dataclasses.dataclass)* 634 (defaults to `True`) 635 - `order : bool` 636 whether to add rich comparison methods 637 *(passed to dataclasses.dataclass)* 638 (defaults to `False`) 639 - `unsafe_hash : bool` 640 whether to add a `__hash__` method 641 *(passed to dataclasses.dataclass)* 642 (defaults to `False`) 643 - `frozen : bool` 644 whether to make the class frozen 645 *(passed to dataclasses.dataclass)* 646 (defaults to `False`) 647 - `properties_to_serialize : Optional[list[str]]` 648 which properties to add to the serialized data dict 649 **SerializableDataclass only** 650 (defaults to `None`) 651 - `register_handler : bool` 652 if true, register the class with ZANJ for loading 653 **SerializableDataclass only** 654 (defaults to `True`) 655 - `on_typecheck_error : ErrorMode` 656 what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false 657 **SerializableDataclass only** 658 - `on_typecheck_mismatch : ErrorMode` 659 what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True` 660 **SerializableDataclass only** 661 - `methods_no_override : list[str]|None` 662 list of methods that should not be overridden by the decorator 663 by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function, 664 but you can disable this if you'd rather write your own. 
`dataclasses.dataclass` might still overwrite these, and those options take precedence 665 **SerializableDataclass only** 666 (defaults to `None`) 667 - `**kwargs` 668 *(passed to dataclasses.dataclass)* 669 670 # Returns: 671 672 - `_type_` 673 the decorated class 674 675 # Raises: 676 677 - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this 678 - `NotSerializableFieldException` : if a field is not a `SerializableField` 679 - `FieldSerializationError` : if there is an error serializing a field 680 - `AttributeError` : if a property is not found on the class 681 - `FieldLoadingError` : if there is an error loading a field 682 """ 683 # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: 684 on_typecheck_error = ErrorMode.from_any(on_typecheck_error) 685 on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch) 686 687 if properties_to_serialize is None: 688 _properties_to_serialize: list = list() 689 else: 690 _properties_to_serialize = properties_to_serialize 691 692 def wrap(cls: Type[T]) -> Type[T]: 693 # Modify the __annotations__ dictionary to replace regular fields with SerializableField 694 for field_name, field_type in cls.__annotations__.items(): 695 field_value = getattr(cls, field_name, None) 696 if not isinstance(field_value, SerializableField): 697 if isinstance(field_value, dataclasses.Field): 698 # Convert the field to a SerializableField while preserving properties 699 field_value = SerializableField.from_Field(field_value) 700 else: 701 # Create a new SerializableField 702 field_value = serializable_field() 703 setattr(cls, field_name, field_value) 704 705 # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy 706 if sys.version_info < (3, 10): 707 if "kw_only" in kwargs: 708 if kwargs["kw_only"] == True: # noqa: E712 709 raise KWOnlyError( 710 "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will 
be ignored" 711 ) 712 else: 713 del kwargs["kw_only"] 714 715 # call `dataclasses.dataclass` to set some stuff up 716 cls = dataclasses.dataclass( # type: ignore[call-overload] 717 cls, 718 init=init, 719 repr=repr, 720 eq=eq, 721 order=order, 722 unsafe_hash=unsafe_hash, 723 frozen=frozen, 724 **kwargs, 725 ) 726 727 # copy these to the class 728 cls._properties_to_serialize = _properties_to_serialize.copy() # type: ignore[attr-defined] 729 730 # ====================================================================== 731 # define `serialize` func 732 # done locally since it depends on args to the decorator 733 # ====================================================================== 734 def serialize(self) -> dict[str, Any]: 735 result: dict[str, Any] = { 736 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 737 } 738 # for each field in the class 739 for field in dataclasses.fields(self): # type: ignore[arg-type] 740 # need it to be our special SerializableField 741 if not isinstance(field, SerializableField): 742 raise NotSerializableFieldException( 743 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 744 f"but a {type(field)} " 745 "this state should be inaccessible, please report this bug!" 746 ) 747 748 # try to save it 749 if field.serialize: 750 try: 751 # get the val 752 value = getattr(self, field.name) 753 # if it is a serializable dataclass, serialize it 754 if isinstance(value, SerializableDataclass): 755 value = value.serialize() 756 # if the value has a serialization function, use that 757 if hasattr(value, "serialize") and callable(value.serialize): 758 value = value.serialize() 759 # if the field has a serialization function, use that 760 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        if "validate_field_type" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
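The final `_cls is None` check in the listing above is what allows a decorator to be applied both bare and with keyword arguments. A minimal sketch of that pattern, using a hypothetical `tagged` decorator that is not part of this module:

```python
from typing import Callable, Optional, Union


def tagged(
    _cls: Optional[type] = None, *, tag: str = "default"
) -> Union[type, Callable[[type], type]]:
    """hypothetical decorator: attaches a `tag` attribute to the class"""

    def wrap(cls: type) -> type:
        cls.tag = tag
        return cls

    if _cls is None:
        # used as `@tagged(...)`: return the real decorator
        return wrap
    # used as bare `@tagged`: decorate immediately
    return wrap(_cls)


@tagged
class A:
    pass


@tagged(tag="custom")
class B:
    pass
```

Because the keyword arguments are captured by the closure, helper functions defined inside `wrap` can depend on them, which is presumably why `serialize` and `load` are defined locally in the listing.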
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
special warning for when we can't get type hints
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
special warning for when ZANJ is missing -- `register_loader_serializable_dataclass` will not work
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
Register a serializable dataclass with the ZANJ import
this allows `ZANJ().read()` to load the class and not just return plain dicts
TODO: there is some duplication here with register_loader_handler
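The `check` predicate passed to the loader handler can be looked at in isolation. The sketch below reproduces only that predicate and does not import ZANJ. `FORMAT_KEY` is a hypothetical stand-in for the module's `_FORMAT_KEY` (its real value is not shown in this listing), and `make_check` is an illustrative helper, not part of the library.

```python
from typing import Any, Callable

# hypothetical stand-in for `_FORMAT_KEY`; the real value is an assumption here
FORMAT_KEY: str = "__format__"


def make_check(cls_name: str) -> Callable[[Any], bool]:
    """build a predicate that recognizes serialized instances of `cls_name`"""
    fmt: str = f"{cls_name}(SerializableDataclass)"

    def check(json_item: Any) -> bool:
        # same three conditions as the lambda in the listing above
        return (
            isinstance(json_item, dict)
            and FORMAT_KEY in json_item
            and json_item[FORMAT_KEY].startswith(fmt)
        )

    return check


check_myclass = make_check("MyClass")
matches = check_myclass({FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1})
rejects_other = check_myclass({FORMAT_KEY: "Other(SerializableDataclass)"})
rejects_plain = check_myclass({"a": 1})
```

Note that the format string is built from the class name only, so two classes with the same name in different modules would produce the same format string.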
Base class for warnings generated by user code.
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
given a dataclass, check the field matches the type hint
this function is written to `SerializableDataclass.validate_field_type`
Parameters:
- `self : SerializableDataclass`
`SerializableDataclass` instance
- `field : SerializableField | str`
field to validate, will get from `self.__dataclass_fields__` if an `str`
- `on_typecheck_error : ErrorMode`
what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
Returns:
- `bool`
if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
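The three modes (except, warn, ignore) only govern what happens when the type checker itself throws. An ordinary mismatch simply yields `False`. A minimal sketch of that distinction, using a hypothetical `Mode` enum as a stand-in for `ErrorMode` (whose real interface is not reproduced here):

```python
import warnings
from enum import Enum


class Mode(Enum):
    """hypothetical stand-in for `ErrorMode`; not the real class"""

    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"


def check_is_int(value: object, mode: Mode) -> bool:
    """return whether `value` is an int; if the check itself fails, react per `mode`"""
    try:
        if isinstance(value, bool):
            # simulate a checker that cannot handle some inputs
            raise RuntimeError("checker cannot handle bool")
        return isinstance(value, int)
    except Exception as e:
        if mode is Mode.EXCEPT:
            raise ValueError("exception while validating type") from e
        if mode is Mode.WARN:
            warnings.warn(f"exception while validating type: {e}")
        # for both WARN and IGNORE the check reports failure
        return False


ok = check_is_int(3, Mode.EXCEPT)
mismatch = check_is_int("3", Mode.EXCEPT)
ignored = check_is_int(True, Mode.IGNORE)
```

Here `mismatch` is `False` without any exception, while `ignored` is `False` because the simulated checker failure was swallowed.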
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
returns a dict of field names to bools, where the bool is if the field type is valid
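The shape of the returned dict can be illustrated with plain `dataclasses`. This is a simplified sketch, not the library's validator: it assumes every annotation is a plain class usable with `isinstance`, and the `Point` class and `validate_fields_dict` helper are hypothetical.

```python
import dataclasses
from typing import get_type_hints


@dataclasses.dataclass
class Point:
    x: int
    label: str


def validate_fields_dict(obj: object) -> dict[str, bool]:
    """map each field name to whether its value matches the annotated class"""
    hints = get_type_hints(type(obj))
    return {
        f.name: isinstance(getattr(obj, f.name), hints[f.name])
        for f in dataclasses.fields(obj)
    }


good = validate_fields_dict(Point(x=1, label="a"))
# dataclasses do not enforce annotations, so this constructs without error
bad = validate_fields_dict(Point(x="oops", label="a"))
```

Reducing such a dict with `all(result.values())` gives a single pass/fail answer, which is how the `bool`-returning variant is built on top of the dict-returning one.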
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
Base class for serializable dataclasses
only for linting and type checking, still need to call `serializable_dataclass` decorator
Usage:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
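The idea of attaching a serialization and a deserialization function to a field can be tried without `torch` or this library. The following is a rough sketch using only the standard library. It stores the two functions in `dataclasses.field` metadata, which is an assumption made for illustration and not how `serializable_field` works internally. The `Config` class is hypothetical.

```python
import dataclasses
import json
from pathlib import Path
from typing import Any


@dataclasses.dataclass
class Config:
    name: str
    # `Path` is not json-serializable by default, so attach converters to the field
    out_dir: Path = dataclasses.field(
        default=Path("."),
        metadata={"serialization_fn": str, "deserialize_fn": Path},
    )

    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {}
        for f in dataclasses.fields(self):
            fn = f.metadata.get("serialization_fn")
            value = getattr(self, f.name)
            result[f.name] = fn(value) if fn else value
        return result

    @classmethod
    def load(cls, data: dict[str, Any]) -> "Config":
        kwargs: dict[str, Any] = {}
        for f in dataclasses.fields(cls):
            if f.name in data:
                fn = f.metadata.get("deserialize_fn")
                kwargs[f.name] = fn(data[f.name]) if fn else data[f.name]
        return cls(**kwargs)


cfg = Config(name="run1", out_dir=Path("results") / "run1")
s = json.dumps(cfg.serialize())
restored = Config.load(json.loads(s))
```

The round trip recovers an equal object, including the `Path` type of `out_dir`, which plain `dataclasses.asdict` followed by `json.dumps` would not manage.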
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )
returns the class as a dict, implemented by using `@serializable_dataclass` decorator
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
given a dataclass, check the field matches the type hint
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
Parameters:
- `other : SerializableDataclass`
other instance to compare against
- `of_serialized : bool`
if true, compare serialized data and not raw values (defaults to `False`)
Returns:
- `dict[str, Any]`
Raises:
- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
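The recursion behind such a diff can be sketched with plain dataclasses. This is an illustration of the technique only: unlike the method documented above, the hypothetical `diff` function below recurses into any dataclass instead of raising for non-serializable ones, compares with `!=` rather than an array-safe equality, and ignores the `compare` flag and `of_serialized`.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


def diff(left: Any, right: Any) -> dict[str, Any]:
    """recursively collect fields whose values differ between two dataclass instances"""
    if type(left) is not type(right):
        raise ValueError("instances must be of the same type")
    result: dict[str, Any] = {}
    for f in dataclasses.fields(left):
        lv, rv = getattr(left, f.name), getattr(right, f.name)
        if dataclasses.is_dataclass(lv) and dataclasses.is_dataclass(rv):
            nested = diff(lv, rv)
            # only record nested fields that actually differ
            if nested:
                result[f.name] = nested
        elif lv != rv:
            result[f.name] = {"self": lv, "other": rv}
    return result


d = diff(Outer(x="q1", y=Inner(a=1, b=2)), Outer(x="q2", y=Inner(a=1, b=3)))
same = diff(Inner(a=1, b=2), Inner(a=1, b=2))
```

Equal nested values are omitted entirely, so an empty dict means the two instances agree on every compared field.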
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
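A sketch of the same update logic on plain dataclasses may make the behavior concrete. The `TrainCfg` and `OptimizerCfg` classes and the free function are hypothetical. The one deliberate difference is that this version recurses into any dataclass rather than only `SerializableDataclass` instances.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class OptimizerCfg:
    lr: float = 0.001
    steps: int = 100


@dataclasses.dataclass
class TrainCfg:
    name: str = "default"
    optimizer: OptimizerCfg = dataclasses.field(default_factory=OptimizerCfg)


def update_from_nested_dict(obj: Any, nested: dict[str, Any]) -> None:
    """overwrite only the fields named in `nested`, recursing into nested dataclasses"""
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current):
            update_from_nested_dict(current, nested[f.name])
        else:
            setattr(obj, f.name, nested[f.name])


cfg = TrainCfg()
update_from_nested_dict(cfg, {"optimizer": {"lr": 0.01}, "unknown_key": 1})
```

Because the loop runs over the instance's fields and not over the dict's keys, keys that match no field (such as `unknown_key`) are silently ignored, and fields absent from the dict keep their current values.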
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
cached typing.get_type_hints for a class
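The effect of the cache can be observed with `cache_info()`. A small sketch with a hypothetical `Example` class and a differently named function, to avoid implying it is the library's own:

```python
import functools
import typing


@functools.lru_cache(typed=True)
def type_hints_cached(cls: type) -> dict[str, typing.Any]:
    """resolve the type hints of `cls` once and reuse the result"""
    return typing.get_type_hints(cls)


class Example:
    a: int
    b: str


first = type_hints_cached(Example)
second = type_hints_cached(Example)
info = type_hints_cached.cache_info()
```

The second call is served from the cache and returns the very same dict object, so callers should treat the result as read-only.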
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints
helper function to get type hints for a class
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
kw-only dataclasses are not supported in python <3.10
Inherited Members
- builtins.NotImplementedError
- NotImplementedError
- builtins.BaseException
- with_traceback
- add_note
- args
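The `kw_only` argument to `dataclasses.dataclass` was added in Python 3.10, so code that must run on older interpreters has to guard it. A minimal sketch of such a guard follows. The `make_dataclass_kwargs` helper is hypothetical and raises a plain `NotImplementedError` rather than `KWOnlyError`.

```python
import dataclasses
import sys


def make_dataclass_kwargs(kw_only: bool) -> dict[str, bool]:
    """build kwargs for `dataclasses.dataclass`, passing `kw_only` only where it exists"""
    if sys.version_info >= (3, 10):
        return {"kw_only": kw_only}
    if kw_only:
        raise NotImplementedError("kw_only requires python >= 3.10")
    # on older versions a falsy `kw_only` is simply dropped
    return {}


@dataclasses.dataclass(**make_dataclass_kwargs(False))
class Pair:
    a: int
    b: str


p = Pair(1, "q")
```

Dropping a falsy `kw_only` keeps the same decorator call working on every version, while a truthy value fails loudly instead of being ignored.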
base class for field errors
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
field is not a SerializableField
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
error while serializing a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
error while loading a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
error when a field type does not match the type hint
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
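Since `FieldTypeMismatchError` derives from both `FieldError` (a `ValueError`) and `TypeError`, callers can catch it under either builtin. A small sketch with hypothetical stand-in classes that mirror this hierarchy:

```python
class FieldErrorSketch(ValueError):
    """hypothetical stand-in for `FieldError`"""


class MismatchSketch(FieldErrorSketch, TypeError):
    """hypothetical stand-in for `FieldTypeMismatchError`"""


def caught_as(exc_type: type) -> bool:
    """raise the mismatch error and report whether `exc_type` catches it"""
    try:
        raise MismatchSketch("expected int, got str")
    except exc_type:
        return True


as_value_error = caught_as(ValueError)
as_type_error = caught_as(TypeError)
as_field_error = caught_as(FieldErrorSketch)
```

This lets existing `except TypeError` handlers keep working while code that handles all field errors uniformly can catch the shared base class.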
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
    - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
       (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
       the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data,
typing.Mapping 806 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 807 808 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 809 810 # initialize dict for keeping what we will pass to the constructor 811 ctor_kwargs: dict[str, Any] = dict() 812 813 # iterate over the fields of the class 814 for field in dataclasses.fields(cls): 815 # check if the field is a SerializableField 816 assert isinstance( 817 field, SerializableField 818 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 819 820 # check if the field is in the data and if it should be initialized 821 if (field.name in data) and field.init: 822 # get the value, we will be processing it 823 value: Any = data[field.name] 824 825 # get the type hint for the field 826 field_type_hint: Any = cls_type_hints.get(field.name, None) 827 828 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 829 if field.deserialize_fn: 830 # if it has a deserialization function, use that 831 value = field.deserialize_fn(value) 832 elif field.loading_fn: 833 # if it has a loading function, use that 834 value = field.loading_fn(data) 835 elif ( 836 field_type_hint is not None 837 and hasattr(field_type_hint, "load") 838 and callable(field_type_hint.load) 839 ): 840 # if no loading function but has a type hint with a load method, use that 841 if isinstance(value, dict): 842 value = field_type_hint.load(value) 843 else: 844 raise FieldLoadingError( 845 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 846 ) 847 else: 848 # assume no loading needs to happen, keep `value` as-is 849 pass 850 851 # store the value in the constructor kwargs 852 ctor_kwargs[field.name] = value 853 854 # create a new instance of the class with the 
constructor kwargs 855 output: cls = cls(**ctor_kwargs) 856 857 # validate the types of the fields if needed 858 if on_typecheck_mismatch != ErrorMode.IGNORE: 859 fields_valid: dict[str, bool] = ( 860 SerializableDataclass__validate_fields_types__dict( 861 output, 862 on_typecheck_error=on_typecheck_error, 863 ) 864 ) 865 866 # if there are any fields that are not valid, raise an error 867 if not all(fields_valid.values()): 868 msg: str = ( 869 f"Type mismatch in fields of {cls.__name__}:\n" 870 + "\n".join( 871 [ 872 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 873 for k, v in fields_valid.items() 874 if not v 875 ] 876 ) 877 ) 878 879 on_typecheck_mismatch.process( 880 msg, except_cls=FieldTypeMismatchError 881 ) 882 883 # return the new instance 884 return output 885 886 _methods_no_override: set[str] 887 if methods_no_override is None: 888 _methods_no_override = set() 889 else: 890 _methods_no_override = set(methods_no_override) 891 892 if _methods_no_override - { 893 "__eq__", 894 "serialize", 895 "load", 896 "validate_fields_types", 897 }: 898 warnings.warn( 899 f"Unknown methods in `methods_no_override`: {_methods_no_override = }" 900 ) 901 902 # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments 903 if "serialize" not in _methods_no_override: 904 # type is `Callable[[T], dict]` 905 cls.serialize = serialize # type: ignore[attr-defined] 906 if "load" not in _methods_no_override: 907 # type is `Callable[[dict], T]` 908 cls.load = load # type: ignore[attr-defined] 909 910 if "validate_field_type" not in _methods_no_override: 911 # type is `Callable[[T, ErrorMode], bool]` 912 cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] 913 914 if "__eq__" not in _methods_no_override: 915 # type is `Callable[[T, T], bool]` 916 cls.__eq__ = lambda self, other: dc_eq(self, other) # type: 
ignore[assignment] 917 918 # Register the class with ZANJ 919 if register_handler: 920 zanj_register_loader_serializable_dataclass(cls) 921 922 return cls 923 924 if _cls is None: 925 return wrap 926 else: 927 return wrap(_cls)
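The `serialize` and `load` closures above follow a common round-trip pattern: walk the dataclass fields, recurse into nested objects, and apply a per-field conversion function where one is configured. The following is a minimal standard-library sketch of that pattern, not the library's implementation. The names `FORMAT_KEY`, `to_dict`, `from_dict`, and the `metadata` keys are hypothetical stand-ins chosen for illustration.

```python
import dataclasses
import json
from typing import Any, Dict

FORMAT_KEY = "__format__"  # hypothetical stand-in for the library's _FORMAT_KEY


def to_dict(obj: Any) -> Dict[str, Any]:
    """Serialize a dataclass instance field by field, recursing into nested dataclasses."""
    result: Dict[str, Any] = {FORMAT_KEY: type(obj).__name__}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            # nested dataclass: serialize it recursively
            value = to_dict(value)
        elif "serialization_fn" in f.metadata:
            # per-field override for values json cannot handle directly
            value = f.metadata["serialization_fn"](value)
        result[f.name] = value
    return result


def from_dict(cls: type, data: Dict[str, Any]) -> Any:
    """Rebuild an instance of `cls` from a dict produced by `to_dict`."""
    ctor_kwargs: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name not in data or not f.init:
            continue
        value = data[f.name]
        if "deserialize_fn" in f.metadata:
            value = f.metadata["deserialize_fn"](value)
        elif isinstance(value, dict) and dataclasses.is_dataclass(f.type):
            # the annotation is itself a dataclass, so load it recursively
            value = from_dict(f.type, value)
        ctor_kwargs[f.name] = value
    return cls(**ctor_kwargs)


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    # a set is not json-serializable, so store it as a sorted list
    tags: set = dataclasses.field(
        default_factory=set,
        metadata={"serialization_fn": sorted, "deserialize_fn": set},
    )


original = Outer(x="q", y=Inner(a=1, b="q"), tags={"b", "a"})
text = json.dumps(to_dict(original))
restored = from_dict(Outer, json.loads(text))
```

The sketch ignores the format key when loading and does no type validation, both of which the real `load` handles differently.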
Decorator to make a dataclass serializable. The class must also inherit from SerializableDataclass!!
Types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE.
The behavior of most kwargs matches that of dataclasses.dataclass, but there are some additional kwargs. Any kwargs not listed here are passed to dataclasses.dataclass.
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.
If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.
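Since these flags are forwarded to dataclasses.dataclass, their effect can be illustrated with the standard library alone. The sketch below uses a hypothetical `Version` class and assumes the forwarded flags behave exactly as they do for a plain dataclass.

```python
import dataclasses


@dataclasses.dataclass(order=True, frozen=True)
class Version:
    major: int
    minor: int


older = Version(1, 2)
newer = Version(1, 10)

# order=True adds rich comparisons that compare fields as a tuple, in definition order
is_older = older < newer

# frozen=True makes assignment after construction raise FrozenInstanceError
try:
    older.major = 2  # type: ignore[misc]
    assignment_blocked = False
except dataclasses.FrozenInstanceError:
    assignment_blocked = True
```

With `eq` left at its default and `frozen=True`, the dataclass also gets a field-based `__hash__`, so equal instances hash equally.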
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
Parameters:
- _cls : _type_
  class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
- init : bool
  whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
- repr : bool
  whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
- order : bool
  whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
- unsafe_hash : bool
  whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
- frozen : bool
  whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
- properties_to_serialize : Optional[list[str]]
  which properties to add to the serialized data dict. SerializableDataclass only (defaults to None)
- register_handler : bool
  if true, register the class with ZANJ for loading. SerializableDataclass only (defaults to True)
- on_typecheck_error : ErrorMode
  what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false. SerializableDataclass only
- on_typecheck_mismatch : ErrorMode
  what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True. SerializableDataclass only
- methods_no_override : list[str]|None
  list of methods that should not be overridden by the decorator. By default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence. SerializableDataclass only (defaults to None)
- **kwargs
  (passed to dataclasses.dataclass)
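The `_cls` and `methods_no_override` parameters reflect two general decorator techniques: supporting both the bare `@decorator` and the `@decorator(...)` forms, and skipping injection of a method the user wants to write by hand. Below is a rough standard-library sketch of both ideas. `simple_serializable` is a hypothetical name, and its `serialize` is just `dataclasses.asdict`, which is far simpler than what this module does.

```python
import dataclasses
from typing import Any, List, Optional


def simple_serializable(
    _cls: Optional[type] = None,
    *,
    methods_no_override: Optional[List[str]] = None,
    **kwargs: Any,
):
    """Hypothetical decorator: usable bare or with arguments."""
    skip = set(methods_no_override or [])

    def wrap(cls: type) -> type:
        # remaining kwargs are forwarded to the stdlib decorator
        cls = dataclasses.dataclass(cls, **kwargs)
        if "serialize" not in skip:
            # inject a default `serialize` unless the user opted out
            cls.serialize = lambda self: dataclasses.asdict(self)
        return cls

    # bare use passes the class directly; use with arguments returns `wrap`
    return wrap if _cls is None else wrap(_cls)


@simple_serializable
class Point:
    x: int
    y: int


@simple_serializable(frozen=True, methods_no_override=["serialize"])
class Custom:
    x: int

    def serialize(self) -> dict:
        # hand-written version that the decorator leaves alone
        return {"custom": self.x}


point_data = Point(1, 2).serialize()
custom_data = Custom(3).serialize()
```

Keeping `_cls` as the only positional parameter is what lets the same function serve both call styles.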
Returns:
- _type_
  the decorated class
Raises:
- KWOnlyError : only raised if kw_only is True and python version is <3.10, since dataclasses.dataclass does not support this
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field
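The KWOnlyError case comes from a version guard: `kw_only` was added to dataclasses.dataclass in Python 3.10, so on older interpreters a true value has to be rejected and a false value dropped before the kwargs are forwarded. A minimal sketch of such a guard follows. The function `drop_unsupported_kw_only` and the exception `KwOnlyUnsupported` are hypothetical names, and the `version` parameter exists only so the behavior can be exercised on any interpreter.

```python
import sys
from typing import Any, Dict, Tuple


class KwOnlyUnsupported(NotImplementedError):
    """Hypothetical error for `kw_only=True` on Python < 3.10."""


def drop_unsupported_kw_only(
    kwargs: Dict[str, Any],
    version: Tuple[int, int] = sys.version_info[:2],
) -> Dict[str, Any]:
    """Return a copy of `kwargs` that is safe to forward to `dataclasses.dataclass`."""
    cleaned = dict(kwargs)
    if version < (3, 10) and "kw_only" in cleaned:
        # compare against True explicitly, since sentinel values can be truthy
        if cleaned["kw_only"] is True:
            raise KwOnlyUnsupported("kw_only requires Python 3.10 or newer")
        # any other value is silently dropped on old interpreters
        del cleaned["kw_only"]
    return cleaned


on_old_python = drop_unsupported_kw_only({"kw_only": False, "eq": True}, version=(3, 9))
on_new_python = drop_unsupported_kw_only({"kw_only": True}, version=(3, 10))
```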