muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, you will get an error when you call `json.dumps(d)`. This module provides a way around that.
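A stdlib-only illustration of the problem follows. The `Config` class is hypothetical, and `PurePosixPath` is just one example of a non-serializable field. `dataclasses.asdict` succeeds, but `json.dumps` rejects the path object:

```python
import dataclasses
import json
from pathlib import PurePosixPath


@dataclasses.dataclass
class Config:
    name: str
    data_dir: PurePosixPath  # not json-serializable by default


cfg = Config(name="run1", data_dir=PurePosixPath("/tmp/data"))

# asdict works: the path object is copied into the dict unchanged
d = dataclasses.asdict(cfg)

error_message = ""
try:
    json.dumps(d)
except TypeError as err:
    # the default json encoder does not know how to handle PurePosixPath
    error_message = str(err)

print(error_message)  # Object of type PurePosixPath is not JSON serializable
```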
Instead, you define your class:
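```python
import json

from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
)

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```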
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

```python
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
```
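The `__format__` key tags the dict as coming from a `SerializableDataclass`. On load, only keys that match the class's fields are passed to the constructor, so the tag is ignored there. Below is a stripped-down, stdlib-only sketch of that round trip. The helper names `tagged_serialize` and `tagged_load` are hypothetical, and the sketch leaves out type validation, nested objects, and custom field hooks:

```python
import dataclasses
import json
from typing import Any


def tagged_serialize(obj: Any) -> dict[str, Any]:
    """Hypothetical helper: dataclass instance -> json-ready dict tagged with `__format__`."""
    result: dict[str, Any] = {
        "__format__": f"{type(obj).__name__}(SerializableDataclass)"
    }
    for field in dataclasses.fields(obj):
        result[field.name] = getattr(obj, field.name)
    return result


def tagged_load(cls: type, data: dict[str, Any]) -> Any:
    """Hypothetical helper: rebuild `cls`; keys that are not init fields (like `__format__`) are ignored."""
    kwargs = {
        field.name: data[field.name]
        for field in dataclasses.fields(cls)
        if field.init and field.name in data
    }
    return cls(**kwargs)


@dataclasses.dataclass
class MyClass:
    a: int
    b: str


my_obj = MyClass(a=1, b="q")
s = json.dumps(tagged_serialize(my_obj))
read_obj = tagged_load(MyClass, json.loads(s))
print(s)  # {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
print(read_obj == my_obj)  # True
```

The library version also validates field types after loading. It registers a ZANJ loader that recognizes the same `__format__` string.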
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
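```python
import torch
from muutils.json_serialize.serializable_dataclass import serializable_field

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        # store only the class name, e.g. "Sigmoid" (`str(x)` would give "Sigmoid()")
        serialization_fn=lambda x: type(x).__name__,
        # rebuild the module from its class name
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```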
which gives us:
```python
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```
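Nested classes and per-field hooks are handled one field at a time. In `serialize()`, a nested `SerializableDataclass` is serialized recursively, and other values go through the field's `serialization_fn`. In `load()`, a `deserialize_fn` is applied if present; otherwise the field's type hint is used to rebuild nested dicts.

The sketch below mimics that flow with plain `dataclasses`. A few details are hypothetical:

- the helper names and the `metadata` keys;
- `PurePosixPath` stands in for `torch.nn.Module`, so the sketch runs without torch;
- annotations are read from `Field.type`, which assumes they are not postponed strings (the real module resolves them with `typing.get_type_hints`).

```python
import dataclasses
import json
from pathlib import PurePosixPath
from typing import Any


def to_json_dict(obj: Any) -> dict[str, Any]:
    """Hypothetical helper: serialize a (possibly nested) dataclass to a json-ready dict."""
    result: dict[str, Any] = {"__format__": f"{type(obj).__name__}(SerializableDataclass)"}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            value = to_json_dict(value)  # recurse into nested dataclasses
        elif "serialization_fn" in f.metadata:
            value = f.metadata["serialization_fn"](value)  # per-field hook
        result[f.name] = value
    return result


def from_json_dict(cls: type, data: dict[str, Any]) -> Any:
    """Hypothetical helper: rebuild `cls` from the output of `to_json_dict`."""
    kwargs: dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if not f.init or f.name not in data:
            continue
        value = data[f.name]
        if "deserialize_fn" in f.metadata:
            value = f.metadata["deserialize_fn"](value)
        elif dataclasses.is_dataclass(f.type) and isinstance(value, dict):
            # the annotation says which class to rebuild (assumes non-string annotations)
            value = from_json_dict(f.type, value)
        kwargs[f.name] = value
    return cls(**kwargs)


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    path: PurePosixPath = dataclasses.field(
        default=PurePosixPath("/tmp"),
        metadata={"serialization_fn": str, "deserialize_fn": PurePosixPath},
    )


outer = Outer(x="q", y=Inner(a=1, b="q"), path=PurePosixPath("/data/run1"))
s = json.dumps(to_json_dict(outer))
restored = from_json_dict(Outer, json.loads(s))
print(s)
```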
````python
"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True

"""

from __future__ import annotations

import abc
import dataclasses
import functools
import json
import sys
import typing
import warnings
from typing import Any, Optional, Type, TypeVar

from muutils.errormode import ErrorMode
from muutils.validate_type import validate_type
from muutils.json_serialize.serializable_field import (
    SerializableField,
    serializable_field,
)
from muutils.json_serialize.util import array_safe_eq, dc_eq

# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access


def _dataclass_transform_mock(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> typing.Callable:
    "mock `typing.dataclass_transform` for python <3.11"

    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn

    return decorator


dataclass_transform: typing.Callable
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


T = TypeVar("T")


class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `zanj_register_loader_serializable_dataclass` will not work"

    pass


_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesn't matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh


_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass


def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
       field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False


def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so use the first exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results


def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))


# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints


class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass


class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isn't a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so that's why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
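`SerializableDataclass.update_from_nested_dict` walks the dataclass fields and recurses into fields that are themselves serializable dataclasses. It overwrites every other field found in the dict; keys that do not name a field are silently ignored.

Below is a standalone sketch of the same logic on plain `dataclasses`. A few parts are hypothetical:

- the free-function form;
- using `dataclasses.is_dataclass` as the recursion test (the method checks for `SerializableDataclass`);
- the `TrainConfig`/`Optim` classes.

```python
import dataclasses
from typing import Any


def update_from_nested_dict(obj: Any, nested_dict: dict[str, Any]) -> None:
    """Hypothetical free-function version of the method: update `obj` in place."""
    for field in dataclasses.fields(obj):
        if field.name not in nested_dict:
            continue  # fields absent from the dict keep their current value
        current = getattr(obj, field.name)
        if dataclasses.is_dataclass(current) and not isinstance(current, type):
            # nested config object: recurse instead of replacing it wholesale
            update_from_nested_dict(current, nested_dict[field.name])
        else:
            setattr(obj, field.name, nested_dict[field.name])


@dataclasses.dataclass
class Optim:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class TrainConfig:
    name: str = "baseline"
    optim: Optim = dataclasses.field(default_factory=Optim)


cfg = TrainConfig()
# "unknown" does not name a field, so it is ignored
update_from_nested_dict(cfg, {"name": "sweep-3", "optim": {"lr": 0.01}, "unknown": 1})
print(cfg)  # TrainConfig(name='sweep-3', optim=Optim(lr=0.01, momentum=0.9))
```

Recursing keeps `momentum` at its existing value. Plain assignment of the inner dict would have replaced the whole `Optim` object with a dict.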
```python
def dataclass_transform(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> _IdentityCallable:
    """Decorator to mark an object as providing dataclass-like behaviour.

    The decorator can be applied to a function, class, or metaclass.

    Example usage with a decorator function::

        @dataclass_transform()
        def create_model[T](cls: type[T]) -> type[T]:
            ...
            return cls

        @create_model
        class CustomerModel:
            id: int
            name: str

    On a base class::

        @dataclass_transform()
        class ModelBase: ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    On a metaclass::

        @dataclass_transform()
        class ModelMeta(type): ...

        class ModelBase(metaclass=ModelMeta): ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    The ``CustomerModel`` classes defined above will
    be treated by type checkers similarly to classes created with
    ``@dataclasses.dataclass``.
    For example, type checkers will assume these classes have
    ``__init__`` methods that accept ``id`` and ``name``.

    The arguments to this decorator can be used to customize this behavior:
    - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
        ``True`` or ``False`` if it is omitted by the caller.
    - ``order_default`` indicates whether the ``order`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``frozen_default`` indicates whether the ``frozen`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``field_specifiers`` specifies a static list of supported classes
        or functions that describe fields, similar to ``dataclasses.field()``.
    - Arbitrary other keyword arguments are accepted in order to allow for
        possible future extensions.

    At runtime, this decorator records its arguments in the
    ``__dataclass_transform__`` attribute on the decorated object.
    It has no other runtime effect.

    See PEP 681 for more details.
    """
    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn
    return decorator
```
Decorator to mark an object as providing dataclass-like behaviour.
The decorator can be applied to a function, class, or metaclass.
Example usage with a decorator function:
```python
@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
    ...
    return cls

@create_model
class CustomerModel:
    id: int
    name: str
```
On a base class:
```python
@dataclass_transform()
class ModelBase: ...

class CustomerModel(ModelBase):
    id: int
    name: str
```
On a metaclass:
```python
@dataclass_transform()
class ModelMeta(type): ...

class ModelBase(metaclass=ModelMeta): ...

class CustomerModel(ModelBase):
    id: int
    name: str
```
The `CustomerModel` classes defined above will be treated by type checkers similarly to classes created with `@dataclasses.dataclass`. For example, type checkers will assume these classes have `__init__` methods that accept `id` and `name`.
The arguments to this decorator can be used to customize this behavior:
- `eq_default` indicates whether the `eq` parameter is assumed to be `True` or `False` if it is omitted by the caller.
- `order_default` indicates whether the `order` parameter is assumed to be `True` or `False` if it is omitted by the caller.
- `kw_only_default` indicates whether the `kw_only` parameter is assumed to be `True` or `False` if it is omitted by the caller.
- `frozen_default` indicates whether the `frozen` parameter is assumed to be `True` or `False` if it is omitted by the caller.
- `field_specifiers` specifies a static list of supported classes or functions that describe fields, similar to `dataclasses.field()`.
- Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.
At runtime, this decorator records its arguments in the `__dataclass_transform__` attribute on the decorated object. It has no other runtime effect.
See PEP 681 for more details.
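Since the decorator only records its arguments at runtime, one way to see this is to inspect `__dataclass_transform__` after decoration. A minimal sketch, assuming Python 3.11+ for `typing.dataclass_transform`; on older versions it falls back to a simplified, hypothetical stand-in that just stores its keyword arguments. `create_model` and `CustomerModel` mirror the example above, except that here `create_model` delegates to `dataclasses.dataclass` so the class really gets an `__init__`.

```python
import dataclasses
import sys

if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:
    # hypothetical fallback: record the arguments and change nothing else
    def dataclass_transform(**kwargs):
        def decorator(cls_or_fn):
            cls_or_fn.__dataclass_transform__ = kwargs
            return cls_or_fn

        return decorator


@dataclass_transform(field_specifiers=(dataclasses.field,))
def create_model(cls):
    # the real work is done by dataclasses; the marker only informs type checkers
    return dataclasses.dataclass(cls)


@create_model
class CustomerModel:
    id: int
    name: str
```

The marker lives on `create_model` itself, not on `CustomerModel`; type checkers read it statically, while runtime behaviour comes entirely from `dataclasses.dataclass`.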
```python
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
```
special warning for when we can't get type hints
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
```
special warning for when ZANJ is missing; `zanj_register_loader_serializable_dataclass` will not work
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
```python
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
```
Register a serializable dataclass with ZANJ
this allows `ZANJ().read()` to load the class instead of just returning plain dicts
TODO: there is some duplication here with `register_loader_handler`
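The handler built above pairs a `check` predicate (does the dict's `__format__` start with `"<ClassName>(SerializableDataclass)"`?) with a `load` callable. The sketch below reproduces that dispatch with a hypothetical in-memory registry; `Handler`, `REGISTRY`, `register_format` and `read` are illustrative names only and are not part of the ZANJ API.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Handler:
    # hypothetical stand-in for a loader handler: a predicate plus a load function
    uid: str
    check: Callable[[Any], bool]
    load: Callable[[Any], Any]


REGISTRY: list = []


def register_format(fmt: str, load_fn: Callable[[dict], Any]) -> Handler:
    """register `load_fn` for dicts whose `__format__` starts with `fmt`"""
    handler = Handler(
        uid=fmt,
        check=lambda item: (
            isinstance(item, dict)
            and isinstance(item.get("__format__"), str)
            and item["__format__"].startswith(fmt)
        ),
        load=load_fn,
    )
    REGISTRY.append(handler)
    return handler


def read(item: Any) -> Any:
    """load with the first matching handler, or return the item unchanged"""
    for handler in REGISTRY:
        if handler.check(item):
            return handler.load(item)
    return item
```

Anything without a matching `__format__` tag passes through untouched, which corresponds to the "plain dicts" behaviour described above.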
Base class for warnings generated by user code.
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
```python
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
```
given a dataclass instance, check that the value of a field matches its type hint
this function backs `SerializableDataclass.validate_field_type`
Parameters:
- `self : SerializableDataclass`
  `SerializableDataclass` instance
- `field : SerializableField | str`
  field to validate; if a `str` is passed, the field is looked up in `self.__dataclass_fields__`
- `on_typecheck_error : ErrorMode`
  what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
Returns:
- `bool`
  `True` if the field type is correct; `False` if the field type is incorrect, or if an exception is thrown and `on_typecheck_error` is `ignore`
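A reduced sketch of what this check does for a single field, assuming only plain classes need validating: it looks up the hint with `typing.get_type_hints` and uses `isinstance`. muutils delegates to `validate_type`, which also understands generics; here a parameterized generic such as `list[int]` makes `isinstance` raise, which exercises the except/warn/ignore handling. The function name and the string `on_error` values are hypothetical stand-ins for the real function and `ErrorMode`.

```python
import dataclasses
import typing
import warnings
from typing import Any


def validate_field_type_sketch(obj: Any, field_name: str, on_error: str = "except") -> bool:
    """check one field's value against its type hint, handling plain classes only"""
    hint = typing.get_type_hints(type(obj))[field_name]
    value = getattr(obj, field_name)
    try:
        # isinstance raises TypeError for parameterized generics such as list[int]
        return isinstance(value, hint)
    except TypeError as e:
        msg = f"exception while validating type of {field_name!r} against {hint!r}"
        if on_error == "except":
            raise ValueError(msg) from e
        if on_error == "warn":
            warnings.warn(msg)
        # "warn" and "ignore" both report the field as not valid
        return False


@dataclasses.dataclass
class Example:
    a: int
    b: list[int]
```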
```python
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
returns a dict mapping each field name to a bool indicating whether that field's type is valid
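The dict form and the boolean form relate as sketched below: one entry per field, and the boolean version is simply `all()` over the values. The names are hypothetical, plain `isinstance` stands in for `validate_type`, and unlike the original the sketch does not bundle exceptions and forward them to an `ErrorMode`.

```python
import dataclasses
import typing
from typing import Any


def validate_fields_dict(obj: Any) -> dict:
    """map each field name to whether its value passes a plain isinstance check (simplified)"""
    hints = typing.get_type_hints(type(obj))
    results: dict = {}
    for field in dataclasses.fields(obj):
        try:
            results[field.name] = isinstance(getattr(obj, field.name), hints[field.name])
        except TypeError:
            # unsupported hint (e.g. a parameterized generic): count the field as not valid
            results[field.name] = False
    return results


def validate_fields(obj: Any) -> bool:
    # the boolean form reduces the per-field dict with all()
    return all(validate_fields_dict(obj).values())


@dataclasses.dataclass
class Config:
    lr: float
    name: str
```

Keeping the per-field dict around is what lets the loader report exactly which fields mismatched, as in the error message built in `load`.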
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field and returns `True` only if every field is valid
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
This class exists only for linting and type checking; you still need to apply the `serializable_dataclass` decorator.
Usage:
```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
```python
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
```
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        # store the class name ("Sigmoid"); str(x) would give "Sigmoid()", which getattr cannot resolve
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
```python
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```
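To show what the decorator contributes, here is a hypothetical, stdlib-only `mini_serializable` decorator that produces the same `__format__`-tagged dicts for the `MyClass`/`NestedClass` example (without the torch field). It is a sketch of the idea, not the muutils implementation: there is no type validation, no custom serialization functions, no ZANJ registration, and nested loading relies on the field's type hint exposing `load`.

```python
import dataclasses
import json
import typing
from typing import Any


def mini_serializable(cls):
    """hypothetical decorator: make `cls` a dataclass and add `__format__`-tagged serialize()/load()"""
    cls = dataclasses.dataclass(cls)
    fmt = f"{cls.__name__}(SerializableDataclass)"

    def serialize(self) -> dict:
        out: dict = {"__format__": fmt}
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            # nested instances that know how to serialize themselves are recursed into
            out[f.name] = value.serialize() if hasattr(value, "serialize") else value
        return out

    def load(klass, data: dict):
        hints = typing.get_type_hints(klass)
        kwargs: dict = {}
        for f in dataclasses.fields(klass):
            value = data[f.name]
            # nested dicts are handed to the field type's own load()
            if isinstance(value, dict) and hasattr(hints[f.name], "load"):
                value = hints[f.name].load(value)
            kwargs[f.name] = value
        return klass(**kwargs)

    cls.serialize = serialize
    cls.load = classmethod(load)
    return cls


@mini_serializable
class MyClass:
    a: int
    b: str


@mini_serializable
class NestedClass:
    x: str
    y: MyClass
```

Usage follows the same round trip as above: `NestedClass.load(json.loads(json.dumps(nc.serialize())))` rebuilds an equal instance, with the nested `MyClass` restored from its tagged sub-dict.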
def serialize(self) -> dict[str, Any]:
returns the instance as a dict; implemented by the `@serializable_dataclass` decorator
@classmethod
def load(cls: Type[T], data: dict[str, Any] | T) -> T:
takes in an appropriately structured dict and returns an instance of the class; implemented by the `@serializable_dataclass` decorator
def validate_fields_types(self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool:
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
def validate_field_type(self, field: "SerializableField|str", on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool:
given a dataclass, check the field matches the type hint
def diff(self, other: "SerializableDataclass", of_serialized: bool = False) -> dict[str, Any]:
get a rich and recursive diff between two instances of a serializable dataclass
```python
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```
Parameters:
- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values (defaults to `False`)
Returns:
- `dict[str, Any]`
  maps each differing field name to `{"self": ..., "other": ...}`, or to a nested diff for nested `SerializableDataclass` fields; empty if the instances are equal
Raises:
- `ValueError` : if the instances are not of the same type
- `ValueError` : if a pair of field values are `dataclasses.dataclass` instances but not `SerializableDataclass`
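A stdlib-only sketch of the same recursive diff, assuming plain `dataclasses` throughout. Two deliberate differences from the method documented above: nested plain dataclasses are recursed into instead of raising `ValueError`, and values are compared with `!=` rather than an array-safe equality helper, so it is not suitable for numpy arrays or tensors. `dc_diff`, `Inner` and `Outer` are hypothetical.

```python
import dataclasses
from typing import Any


def dc_diff(a: Any, b: Any) -> dict:
    """recursive field-by-field diff of two instances of the same dataclass"""
    if type(a) is not type(b):
        raise ValueError(f"Instances must be of the same type, got {type(a)} and {type(b)}")
    result: dict = {}
    for f in dataclasses.fields(a):
        if not f.compare:
            continue  # skip fields excluded from comparison, as the method does
        va, vb = getattr(a, f.name), getattr(b, f.name)
        if dataclasses.is_dataclass(va) and dataclasses.is_dataclass(vb):
            nested = dc_diff(va, vb)
            if nested:  # only record nested fields that actually differ
                result[f.name] = nested
        elif va != vb:
            result[f.name] = {"self": va, "other": vb}
    return result


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
```

The output shape matches the examples above: leaf differences become `{"self": ..., "other": ...}` and nested instances contribute a nested dict only when something inside them changed.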
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
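A sketch of the same recursive update on plain dataclasses, for example to override a config from parsed command-line arguments. `update_from_nested_dict_sketch`, `Optim` and `TrainConfig` are hypothetical, and the extra check that the incoming value is a dict before recursing is an addition not present in the method above.

```python
import dataclasses
from typing import Any


def update_from_nested_dict_sketch(obj: Any, nested: dict) -> None:
    """recursively overwrite fields of a dataclass instance from a nested dict"""
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue  # fields absent from the dict keep their current value
        current = getattr(obj, f.name)
        new_value = nested[f.name]
        if dataclasses.is_dataclass(current) and isinstance(new_value, dict):
            # recurse so that sibling fields of the nested instance are preserved
            update_from_nested_dict_sketch(current, new_value)
        else:
            setattr(obj, f.name, new_value)


@dataclasses.dataclass
class Optim:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class TrainConfig:
    epochs: int = 10
    optim: Optim = dataclasses.field(default_factory=Optim)
```

Recursing rather than replacing the nested object is the key design choice: `{"optim": {"lr": 0.01}}` changes only the learning rate and leaves `momentum` alone.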
```python
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
```
cached typing.get_type_hints for a class
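Caching works because classes are hashable and can serve as `functools.lru_cache` keys. A small sketch with a hypothetical class `A` (one annotation given as a string forward reference); note that every call returns the same dict object, so callers should treat the result as read-only.

```python
import functools
import typing


@functools.lru_cache(typed=True)
def cached_hints(cls: type) -> dict:
    # the class object itself is the cache key
    return typing.get_type_hints(cls)


class A:
    x: int
    y: "str"  # string annotations are resolved by get_type_hints
```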
```python
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints
```
helper function to get type hints for a class
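The helper above funnels every failure mode into a single `TypeError`. A sketch of that behaviour with hypothetical classes: one resolvable, one with an unresolvable string forward reference (which makes `typing.get_type_hints` raise `NameError`), and one with no annotations at all (treated as an error rather than an empty result).

```python
import typing


def get_hints_or_raise(cls: type) -> dict:
    """turn resolution failures and empty hints into a single TypeError"""
    try:
        hints = typing.get_type_hints(cls)
        if len(hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__}")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(f"Cannot get type hints for {cls!r}: {e!r}") from e
    return hints


class Good:
    a: int


class Unresolvable:
    b: "DoesNotExist"  # forward reference that cannot be resolved


class Empty:
    pass
```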
```python
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
```
kw-only dataclasses are not supported in python <3.10
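The version guard that raises this error can be sketched as follows: `dataclasses.dataclass` only accepts `kw_only` on Python 3.10 and later, so on older versions `kw_only=True` is rejected and `kw_only=False` (the default anyway) is simply dropped. `make_dataclass_compat` and `KWOnlyUnsupported` are hypothetical stand-ins, not muutils names.

```python
import dataclasses
import sys
from typing import Any


class KWOnlyUnsupported(NotImplementedError):
    # hypothetical stand-in for KWOnlyError
    pass


def make_dataclass_compat(cls: type, **kwargs: Any) -> type:
    """apply dataclasses.dataclass, guarding `kw_only`, which only exists on python >= 3.10"""
    if sys.version_info < (3, 10) and "kw_only" in kwargs:
        if kwargs["kw_only"]:
            raise KWOnlyUnsupported("kw_only requires python >= 3.10")
        del kwargs["kw_only"]  # False is the default, so it can be dropped safely
    return dataclasses.dataclass(cls, **kwargs)


class Point:
    x: int
    y: int
```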
Inherited Members
- builtins.NotImplementedError
- NotImplementedError
- builtins.BaseException
- with_traceback
- add_note
- args
base class for field errors
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
```
field is not a SerializableField
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
error while serializing a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
error while loading a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
```
error when a field type does not match the type hint
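Because `FieldTypeMismatchError` inherits from both `FieldError` (a `ValueError`) and `TypeError`, callers can catch it under either built-in. A small sketch reproducing that hierarchy locally; `check_int` is a hypothetical validator used only to raise the error.

```python
class FieldError(ValueError):
    "base class for field errors"


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"


def check_int(value: object) -> int:
    # hypothetical validator raising the dual-purpose error
    if not isinstance(value, int):
        raise FieldTypeMismatchError(f"expected int, got {type(value).__name__}")
    return value
```

Code written against plain `TypeError` (the conventional exception for wrong types) and code that catches all field errors via `FieldError` both handle the mismatch without knowing about the specific subclass.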
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        (defaults to `True`)
     - `repr : bool`
        (defaults to `True`)
     - `eq : bool`
        (defaults to `True`)
     - `order : bool`
        (defaults to `False`)
     - `unsafe_hash : bool`
        (defaults to `False`)
     - `frozen : bool`
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
        the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
````
"this state should be inaccessible, please report this bug!" 717 ) 718 719 # try to save it 720 if field.serialize: 721 try: 722 # get the val 723 value = getattr(self, field.name) 724 # if it is a serializable dataclass, serialize it 725 if isinstance(value, SerializableDataclass): 726 value = value.serialize() 727 # if the value has a serialization function, use that 728 if hasattr(value, "serialize") and callable(value.serialize): 729 value = value.serialize() 730 # if the field has a serialization function, use that 731 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 732 elif field.serialization_fn: 733 value = field.serialization_fn(value) 734 735 # store the value in the result 736 result[field.name] = value 737 except Exception as e: 738 raise FieldSerializationError( 739 "\n".join( 740 [ 741 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 742 f"{field = }", 743 f"{value = }", 744 f"{self = }", 745 ] 746 ) 747 ) from e 748 749 # store each property if we can get it 750 for prop in self._properties_to_serialize: 751 if hasattr(cls, prop): 752 value = getattr(self, prop) 753 result[prop] = value 754 else: 755 raise AttributeError( 756 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 757 + f"but it is in {self._properties_to_serialize = }" 758 + f"\n{self = }" 759 ) 760 761 return result 762 763 # ====================================================================== 764 # define `load` func 765 # done locally since it depends on args to the decorator 766 # ====================================================================== 767 # mypy thinks this isnt a classmethod 768 @classmethod # type: ignore[misc] 769 def load(cls, data: dict[str, Any] | T) -> Type[T]: 770 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 771 if isinstance(data, 
cls): 772 return data 773 774 assert isinstance( 775 data, typing.Mapping 776 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 777 778 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 779 780 # initialize dict for keeping what we will pass to the constructor 781 ctor_kwargs: dict[str, Any] = dict() 782 783 # iterate over the fields of the class 784 for field in dataclasses.fields(cls): 785 # check if the field is a SerializableField 786 assert isinstance( 787 field, SerializableField 788 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 789 790 # check if the field is in the data and if it should be initialized 791 if (field.name in data) and field.init: 792 # get the value, we will be processing it 793 value: Any = data[field.name] 794 795 # get the type hint for the field 796 field_type_hint: Any = cls_type_hints.get(field.name, None) 797 798 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 799 if field.deserialize_fn: 800 # if it has a deserialization function, use that 801 value = field.deserialize_fn(value) 802 elif field.loading_fn: 803 # if it has a loading function, use that 804 value = field.loading_fn(data) 805 elif ( 806 field_type_hint is not None 807 and hasattr(field_type_hint, "load") 808 and callable(field_type_hint.load) 809 ): 810 # if no loading function but has a type hint with a load method, use that 811 if isinstance(value, dict): 812 value = field_type_hint.load(value) 813 else: 814 raise FieldLoadingError( 815 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 816 ) 817 else: 818 # assume no loading needs to happen, keep `value` as-is 819 pass 820 821 # store the value in the constructor kwargs 822 ctor_kwargs[field.name] = value 
823 824 # create a new instance of the class with the constructor kwargs 825 output: cls = cls(**ctor_kwargs) 826 827 # validate the types of the fields if needed 828 if on_typecheck_mismatch != ErrorMode.IGNORE: 829 fields_valid: dict[str, bool] = ( 830 SerializableDataclass__validate_fields_types__dict( 831 output, 832 on_typecheck_error=on_typecheck_error, 833 ) 834 ) 835 836 # if there are any fields that are not valid, raise an error 837 if not all(fields_valid.values()): 838 msg: str = ( 839 f"Type mismatch in fields of {cls.__name__}:\n" 840 + "\n".join( 841 [ 842 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 843 for k, v in fields_valid.items() 844 if not v 845 ] 846 ) 847 ) 848 849 on_typecheck_mismatch.process( 850 msg, except_cls=FieldTypeMismatchError 851 ) 852 853 # return the new instance 854 return output 855 856 # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments 857 # type is `Callable[[T], dict]` 858 cls.serialize = serialize # type: ignore[attr-defined] 859 # type is `Callable[[dict], T]` 860 cls.load = load # type: ignore[attr-defined] 861 # type is `Callable[[T, ErrorMode], bool]` 862 cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] 863 864 # type is `Callable[[T, T], bool]` 865 if not hasattr(cls, "__eq__"): 866 cls.__eq__ = lambda self, other: dc_eq(self, other) # type: ignore[assignment] 867 868 # Register the class with ZANJ 869 if register_handler: 870 zanj_register_loader_serializable_dataclass(cls) 871 872 return cls 873 874 if _cls is None: 875 return wrap 876 else: 877 return wrap(_cls)
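The per-field dispatch in `load` above tries `deserialize_fn` first (it receives only that field's value), then `loading_fn` (it receives the whole data dict), then a callable `.load` on the field's type hint (which must be given a dict), and otherwise keeps the value unchanged. Below is a minimal standalone sketch of that priority order, not the muutils implementation; `FieldSpec`, `load_field`, and `Point` are hypothetical names used only for illustration, and a plain `TypeError` stands in for `FieldLoadingError`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class FieldSpec:
    """hypothetical stand-in for the loading options a field can carry"""

    deserialize_fn: Optional[Callable[[Any], Any]] = None  # gets this field's value only
    loading_fn: Optional[Callable[[dict], Any]] = None  # gets the entire data dict


def load_field(name: str, data: dict, spec: FieldSpec, type_hint: Any) -> Any:
    """load one field, trying each strategy in the same order as `load` above"""
    value = data[name]
    if spec.deserialize_fn:
        return spec.deserialize_fn(value)
    if spec.loading_fn:
        return spec.loading_fn(data)
    nested_load = getattr(type_hint, "load", None)
    if callable(nested_load):
        # nested serializable classes are expected to arrive as dicts
        if not isinstance(value, dict):
            raise TypeError(f"cannot load {name!r} into {type_hint}: expected a dict, got {type(value)}")
        return nested_load(value)
    # no loader applies: keep the value as-is
    return value


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    @classmethod
    def load(cls, d: dict) -> "Point":
        return cls(d["x"], d["y"])


data = {"n": "3", "p": {"x": 1, "y": 2}, "tag": "a"}
print(load_field("n", data, FieldSpec(deserialize_fn=int), int))  # 3
print(load_field("p", data, FieldSpec(), Point).x)  # 1
print(load_field("tag", data, FieldSpec(), str))  # a
```

Because `loading_fn` sees the whole dict, it can derive a field from other keys, which `deserialize_fn` cannot.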
decorator to make a dataclass serializable. the decorated class must also inherit from SerializableDataclass
types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE
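A minimal sketch of what the three modes mean for a type mismatch found after loading, assuming plain classes as type hints and `isinstance` checks. The real module uses its own `validate_type` helper and `ErrorMode` enum; `check_loaded_types` is a hypothetical name and the modes are plain strings here.

```python
import warnings
from types import SimpleNamespace
from typing import Any, Optional


def check_loaded_types(obj: Any, hints: dict, on_mismatch: str = "warn") -> Optional[dict]:
    """hypothetical: validate attributes of `obj` against simple type hints"""
    if on_mismatch == "ignore":
        return None  # validation is skipped entirely
    valid = {name: isinstance(getattr(obj, name), hint) for name, hint in hints.items()}
    bad = [name for name, ok in valid.items() if not ok]
    if bad:
        msg = f"type mismatch in fields: {bad}"
        if on_mismatch == "except":
            raise TypeError(msg)
        warnings.warn(msg)  # "warn": report it, but keep the loaded object
    return valid


obj = SimpleNamespace(a=1, b=2)
print(check_loaded_types(obj, {"a": int, "b": str}, "ignore"))  # None
```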
behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.
If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.
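Since these flags are forwarded to `dataclasses.dataclass`, their effect can be seen with a plain standard-library dataclass. This sketch uses a hypothetical `Version` class and does not involve `SerializableDataclass` at all.

```python
import dataclasses


@dataclasses.dataclass(order=True, frozen=True)
class Version:
    major: int
    minor: int


v1, v2 = Version(1, 2), Version(1, 10)
print(v1 < v2)  # True: order=True compares (major, minor) field by field
print(v1)  # Version(major=1, minor=2), from the generated __repr__
try:
    v1.major = 3  # frozen=True forbids assignment after __init__
except dataclasses.FrozenInstanceError as err:
    print("frozen:", err)
print(hash(v1) == hash(Version(1, 2)))  # True: eq + frozen yields a field-based __hash__
```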
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
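The output above has the shape produced by the generated `serialize` method: a `__format__` tag naming the class, followed by one key per field, with nested serializable dataclasses serialized recursively. A minimal standalone sketch of that shape using only the standard library; `toy_serialize` and `Outer` are hypothetical, and the helper skips `serialization_fn`, properties, and error handling, so it is not a substitute for the decorator.

```python
import dataclasses
import json
from typing import Any


def toy_serialize(obj: Any) -> dict:
    """hypothetical: tag with the class name, then copy fields, recursing into dataclasses"""
    result: dict = {"__format__": f"{type(obj).__name__}(SerializableDataclass)"}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        # recurse into dataclass instances (is_dataclass is also True for the classes themselves)
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            value = toy_serialize(value)
        result[f.name] = value
    return result


@dataclasses.dataclass
class Myclass:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Myclass


print(toy_serialize(Myclass(a=1, b="q")))
# {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
print(json.dumps(toy_serialize(Outer(x="q", y=Myclass(a=1, b="q")))))
```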
Parameters:
- _cls : _type_
  class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
- init : bool
  (defaults to True)
- repr : bool
  (defaults to True)
- order : bool
  (defaults to False)
- unsafe_hash : bool
  (defaults to False)
- frozen : bool
  (defaults to False)
- properties_to_serialize : Optional[list[str]]
  SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
- register_handler : bool
  SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
- on_typecheck_error : ErrorMode
  SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
- on_typecheck_mismatch : ErrorMode
  SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True
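properties_to_serialize only affects the output of serialize: per the source above, load iterates over dataclass fields only, so stored property values are ignored when loading. A standalone sketch of the serialize side, with a hypothetical `Rect` class and `serialize_with_properties` helper; like the generated method, it checks for the property on the class and raises `AttributeError` if it is missing.

```python
import dataclasses
from typing import Any


def serialize_with_properties(obj: Any, properties: list) -> dict:
    """hypothetical: dataclass fields first, then each requested property"""
    result = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
    for prop in properties:
        # look the name up on the class, as the generated `serialize` does
        if not hasattr(type(obj), prop):
            raise AttributeError(f"cannot serialize property {prop!r} on {type(obj).__name__}")
        result[prop] = getattr(obj, prop)
    return result


@dataclasses.dataclass
class Rect:
    w: float
    h: float

    @property
    def area(self) -> float:
        return self.w * self.h


print(serialize_with_properties(Rect(2.0, 3.0), ["area"]))
# {'w': 2.0, 'h': 3.0, 'area': 6.0}
```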
Returns:
- _type_
  the decorated class
Raises:
- KWOnlyError : only raised if kw_only is True and the python version is <3.10, since dataclasses.dataclass does not support kw_only before 3.10
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field
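The KWOnlyError case comes from a version gate that runs before the class is handed to dataclasses.dataclass, which accepts kw_only only on Python 3.10 and later: kw_only=True is rejected on older versions, while any other value is silently dropped. A standalone sketch of that gate; `check_kw_only` and the local `KWOnlyError` class are hypothetical (the real exception's base class is not shown here), and the `version` argument exists only so the behavior can be exercised on any interpreter.

```python
import sys


class KWOnlyError(NotImplementedError):
    """hypothetical stand-in for the module's KWOnlyError"""


def check_kw_only(kwargs: dict, version: tuple = tuple(sys.version_info[:2])) -> dict:
    """drop or reject `kw_only` before forwarding kwargs to dataclasses.dataclass"""
    kwargs = dict(kwargs)  # don't mutate the caller's dict
    if version < (3, 10) and "kw_only" in kwargs:
        # compare against True explicitly: dataclasses.MISSING is truthy
        if kwargs["kw_only"] == True:  # noqa: E712
            raise KWOnlyError("kw_only requires python >= 3.10")
        del kwargs["kw_only"]
    return kwargs


print(check_kw_only({"kw_only": False, "frozen": True}, version=(3, 9)))  # {'frozen': True}
```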