Coverage for muutils\json_serialize\serializable_dataclass.py: 54%

247 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-05 19:24 -0700

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

25This isn't too impressive on its own, but it gets more useful when you have nested classes, 

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

35 serialization_fn=lambda x: str(x), 

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.util import array_safe_eq, dc_eq 

70 

71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

72 

73 

74def _dataclass_transform_mock( 

75 *, 

76 eq_default: bool = True, 

77 order_default: bool = False, 

78 kw_only_default: bool = False, 

79 frozen_default: bool = False, 

80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (), 

81 **kwargs: Any, 

82) -> typing.Callable: 

83 "mock `typing.dataclass_transform` for python <3.11" 

84 

85 def decorator(cls_or_fn): 

86 cls_or_fn.__dataclass_transform__ = { 

87 "eq_default": eq_default, 

88 "order_default": order_default, 

89 "kw_only_default": kw_only_default, 

90 "frozen_default": frozen_default, 

91 "field_specifiers": field_specifiers, 

92 "kwargs": kwargs, 

93 } 

94 return cls_or_fn 

95 

96 return decorator 

97 

98 

# `typing.dataclass_transform` only exists on python >=3.11,
# older interpreters get the local mock instead
dataclass_transform: typing.Callable
if sys.version_info >= (3, 11):
    dataclass_transform = typing.dataclass_transform
else:
    dataclass_transform = _dataclass_transform_mock

104 

105 

# generic type variable used in the annotations of this module
T = TypeVar("T")

107 

108 

class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

113 

114 

class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

119 

120 

121_zanj_loading_needs_import: bool = True 

122"flag to keep track of if we have successfully imported ZANJ" 

123 

124 

def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    returns the registered `LoaderHandler`, or `None` when `zanj` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
            # so we skip silently instead of emitting a `ZanjMissingWarning`
            return

    # the handler is keyed on the same string that `serialize()` writes to `__format__`
    format_key: str = f"{cls.__name__}(SerializableDataclass)"
    handler: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(format_key)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=format_key,
        source_pckg=cls.__module__,
        desc=f"{format_key} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(handler)

    return handler

165 

166 

# module-wide defaults for the `on_typecheck_mismatch` / `on_typecheck_error` arguments used below
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT

169 

170 

class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning emitted when type validation is skipped for a field that is not `init` or not `serialize`"""

173 

174 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of a single field on a dataclass instance matches its type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
       field to validate, looked up in `self.__dataclass_fields__` if given as a `str`
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`.
       `True` is also returned, without checking, when `assert_type` is disabled on the field or the field is not `init`/`serialize`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name to the field object stored on the dataclass
    # NOTE: `_field`, `field_type_hint` and `value` keep their names because they show up
    # verbatim in the self-documenting f-strings (`{name = }`) of the messages below
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type assertion is switched off for this field, nothing to do
    if not _field.assert_type:
        return True

    # fields which are not `init` or not `serialize` are skipped, but we warn about it
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint of the field; a missing hint is a type-check *error*
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                "  - use python 3.9.x or higher\n"
                "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        # only reached when the error mode did not raise
        return False

    # current value stored on the instance
    value: Any = getattr(self, _field.name)

    try:
        if _field.custom_typecheck_fn is not None:
            # NOTE(review): the custom checker is called with the type hint, not with `value` -- confirm this is the intended contract
            return _field.custom_typecheck_fn(field_type_hint)
        # no custom checker: fall back to the default validator
        return validate_type(value, field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        # only reached when the error mode did not raise
        return False

265 

266 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `self.validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid.
    exceptions raised while validating individual fields are collected, the affected fields are
    marked `False`, and the collected exceptions are handed to `on_typecheck_error` in one go
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    # exceptions are bundled so that every field gets checked before anything is reported
    failures: dict[str, Exception] = {}

    for fld in cls_fields:
        try:
            field_ok: bool = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            failures[fld.name] = err
            field_ok = False
        validity[fld.name] = field_ok

    if failures:
        failure_lines: str = "\n\t".join(
            f"{name}:\t{exc}" for name, exc in failures.items()
        )
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[f.name for f in cls_fields]}"
            + "\n\t"
            + failure_lines,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity

302 

303 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    thin wrapper around `SerializableDataclass__validate_fields_types__dict`:
    returns `True` only if every per-field result is truthy
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

314 

315 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator replaces this method on the decorated class
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder: the decorator replaces this method on the decorated class
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. delegates to `SerializableDataclass__validate_fields_types`"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint. delegates to `SerializableDataclass__validate_field_type`"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps the name of each differing field to either a nested diff (for nested
           serializable dataclasses) or `{"self": ..., "other": ...}` holding the raw attribute values.
           empty if the instances compare equal

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides have to be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        changes: dict = {}

        # equal instances have an empty diff
        if self == other:
            return changes

        # serialized views are only computed when they are actually compared
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for fld in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison never show up in the diff
            if not fld.compare:
                continue

            name: str = fld.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # nested serializable dataclasses are diffed recursively
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                sub_diff: dict = mine.diff(theirs, of_serialized=of_serialized)
                if sub_diff:
                    changes[name] = sub_diff
                continue

            # plain dataclasses are not supported
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what to compare: serialized data or the raw attribute values
            # NOTE(review): a compared field missing from the serialized dict would raise `KeyError` here -- confirm whether that can happen
            if of_serialized:
                lhs, rhs = ser_self[name], ser_other[name]
            else:
                lhs, rhs = mine, theirs

            if not array_safe_eq(lhs, rhs):
                # the raw values are reported even when serialized data was compared
                changes[name] = {"self": mine, "other": theirs}

        return changes

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten via `setattr`

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for fld in dataclasses.fields(self):  # type: ignore[arg-type]
            name: str = fld.name
            current = getattr(self, name)

            if name not in nested_dict:
                continue

            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(nested_dict[name])
            else:
                setattr(self, name, nested_dict[name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        # NOTE(review): `copy.copy` therefore also yields a json round-tripped (deep) copy -- confirm this is intended
        round_tripped = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(round_tripped)

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        # `memo` is not consulted, the copy is rebuilt purely from the json round trip
        round_tripped = json.loads(json.dumps(self.serialize()))
        return self.__class__.load(round_tripped)

510 

511 

512# cache this so we don't have to keep getting it 

513# TODO: are the types hashable? does this even make sense? 

@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    """cached typing.get_type_hints for a class

    results are memoized by `functools.lru_cache`, so repeated calls with the
    same class return the same dict object
    """
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints

518 

519 

def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first and retries uncached if that came back empty.
    a `TypeError`, `NameError` or `ValueError` during the lookup (including
    the "empty type hints" case) is re-raised as a `TypeError` with extra context
    """
    cls_type_hints: dict[str, Any]
    try:
        # an empty (falsy) cached result triggers a second, uncached attempt
        cls_type_hints = (
            get_cls_type_hints_cached(cls)  # type: ignore
            or typing.get_type_hints(cls)
        )

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # NOTE: `e` and `cls` keep their names, they appear verbatim in the `{name = }` f-strings
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f"  {e = }"
        ) from e

    return cls_type_hints

539 

540 

class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10"""

545 

546 

class FieldError(ValueError):
    """base class for field errors"""

551 

552 

class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""

557 

558 

class FieldSerializationError(FieldError):
    """error while serializing a field"""

563 

564 

class FieldLoadingError(FieldError):
    """error while loading a field"""

569 

570 

class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint"""

575 

576 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
     - `**kwargs`
       passed through to `dataclasses.dataclass` (e.g. `kw_only`)

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace the class-level value of every annotated name with a `SerializableField`
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        # (hence the explicit `== True` comparison instead of a truthiness test)
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    # a non-`True` value is dropped so `dataclasses.dataclass` never sees the unknown kwarg
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # reset per field: the error message below formats `value`, which would
                    # otherwise be unbound (or stale from the previous field) if `getattr` fails
                    value = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that (called with the field's value)
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that (called with the whole `data` mapping)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, report them according to `on_typecheck_mismatch`
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        # NOTE(review): every class inherits `__eq__` from `object`, so this guard looks like it can never fire -- confirm intent
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)