Coverage for muutils\json_serialize\serializable_dataclass.py: 56%

257 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-14 01:33 -0700

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

25This isn't too impressive on its own, but it gets more useful when you have nested classes, 

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

35 serialization_fn=lambda x: str(x), 

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq 

70 

71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

72 

73 

74def _dataclass_transform_mock( 

75 *, 

76 eq_default: bool = True, 

77 order_default: bool = False, 

78 kw_only_default: bool = False, 

79 frozen_default: bool = False, 

80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (), 

81 **kwargs: Any, 

82) -> typing.Callable: 

83 "mock `typing.dataclass_transform` for python <3.11" 

84 

85 def decorator(cls_or_fn): 

86 cls_or_fn.__dataclass_transform__ = { 

87 "eq_default": eq_default, 

88 "order_default": order_default, 

89 "kw_only_default": kw_only_default, 

90 "frozen_default": frozen_default, 

91 "field_specifiers": field_specifiers, 

92 "kwargs": kwargs, 

93 } 

94 return cls_or_fn 

95 

96 return decorator 

97 

98 

# use the real `typing.dataclass_transform` on python 3.11+, otherwise fall back to the local mock
if sys.version_info >= (3, 11):
    dataclass_transform = typing.dataclass_transform
else:
    dataclass_transform = _dataclass_transform_mock

103 

104 

# generic type variable used in the annotations of this module
T = TypeVar("T")

106 

107 

class CantGetTypeHintsWarning(UserWarning):
    "warning category for the case where type hints cannot be obtained"

112 

113 

class ZanjMissingWarning(UserWarning):
    "warning category for a missing [`ZANJ`](https://github.com/mivanit/ZANJ) install -- without it `register_loader_serializable_dataclass` will not work"

118 

119 

120_zanj_loading_needs_import: bool = True 

121"flag to keep track of if we have successfully imported ZANJ" 

122 

123 

def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
        class to register. Its `__name__` forms the format string
        `"<name>(SerializableDataclass)"` and its `load` is used as the loader

    # Returns:
     - the `LoaderHandler` that was registered,
        or `None` if `zanj.loading` could not be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # The imported names are local to this call, so the import has to run every
    # time: skipping it based on the module flag would leave `LoaderHandler`
    # unbound. Repeated imports are served from `sys.modules`.
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
        # so no warning is emitted here
        return None

    # record the successful import in the module-level flag
    _zanj_loading_needs_import = False

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        # an item matches if it is a dict whose format entry starts with this class's format string
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh

164 

165 

# module-wide defaults for the `on_typecheck_error` / `on_typecheck_mismatch` arguments
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN

168 

169 

class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning category for a field that is not `init` or not `serialize`"

172 

173 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of one field of a dataclass instance matches its type hint

    # Parameters:
     - `self : SerializableDataclass`
        instance whose field is checked
     - `field : SerializableField | str`
        the field itself, or its name (looked up in `self.__dataclass_fields__`)
     - `on_typecheck_error : ErrorMode`
        how to react when the check itself fails (missing type hint, or an
        exception raised by the validator). It is handed the error via `.process(...)`;
        if that call returns, this function returns `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        `True` if the field has `assert_type` disabled, is skipped because it is
        not `init` or not `serialize`, or passes validation.
        `False` if validation fails or an error was processed without raising
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # accept either a field object or a field name
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type checking is switched off for this field
    if not _field.assert_type:
        return True

    # fields that are not both `init` and `serialize` are skipped, with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint for this field
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            "".join(
                [
                    f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n",
                    f"{get_cls_type_hints(self.__class__) = }\n",
                    f"Python version is {sys.version_info = }. You can:\n",
                    f" - disable `assert_type`. Currently: {_field.assert_type = }\n",
                    f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n",
                    " - use python 3.9.x or higher\n",
                    " - specify custom type validation function via `custom_typecheck_fn`\n",
                ]
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field on the instance
    value: Any = getattr(self, _field.name)

    # run the validator; anything it raises is routed through `on_typecheck_error`
    try:
        custom_checker = _field.custom_typecheck_fn
        if custom_checker is None:
            # default validator
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom validator is given the type hint, not the value -- confirm this is intended
        return custom_checker(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

264 

265 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """run `self.validate_field_type` on every dataclass field of `self`

    exceptions raised for individual fields are collected and, if there are any,
    reported together through `on_typecheck_error.process(...)` after all fields were visited

    # Returns:
     - `dict[str, bool]`
        field name -> whether the field validated. fields whose check raised map to `False`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    failures: dict[str, Exception] = {}

    for fld in cls_fields:
        try:
            fld_ok: bool = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            # remember the exception, keep going with the remaining fields
            failures[fld.name] = err
            fld_ok = False
        validity[fld.name] = fld_ok

    if failures:
        all_names = [x.name for x in cls_fields]
        details = "\n\t".join([f"{k}:\t{v}" for k, v in failures.items()])
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {all_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity

301 

302 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """`True` if every field of `self` validates, as reported by `SerializableDataclass__validate_fields_types__dict`"""
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

313 

314 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    exists only for linting and type checking -- the `serializable_dataclass` decorator still has to be applied

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    `my_obj.serialize()` then gives a dict that can be serialized to json:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This becomes more useful with nested classes,
    or with fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "placeholder for the dict serialization; the real implementation is installed by the `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "placeholder for building an instance from a dict; the real implementation is installed by the `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """check the types of all fields; delegates to `SerializableDataclass__validate_fields_types`"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint; delegates to `SerializableDataclass__validate_field_type`"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        # equality is delegated to `dc_eq`
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hash of the json string produced from `self.serialize()`"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """recursive diff between two instances of the same serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            instance to compare against
         - `of_serialized : bool`
            if true, the comparison is done on the serialized data instead of the raw values.
            the reported `"self"`/`"other"` entries are the raw attribute values either way
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
            empty if the instances compare equal. otherwise one entry per differing field:
            a nested diff for serializable-dataclass fields, else `{"self": ..., "other": ...}`

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if a field holds `dataclasses.dataclass` instances which are not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        diff_result: dict = {}

        # shortcut for equal instances; if the comparison itself fails, fall through to the per-field walk
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # serialized views are only computed when they are going to be compared
        ser_self = self.serialize() if of_serialized else None
        ser_other = other.serialize() if of_serialized else None

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are ignored
            if not field.compare:
                continue

            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            both_serializable: bool = isinstance(
                self_value, SerializableDataclass
            ) and isinstance(other_value, SerializableDataclass)

            if both_serializable:
                # recurse, and only record non-empty nested diffs
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
                continue

            # plain dataclasses are rejected
            if dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what gets compared: serialized entries or the raw values
            if of_serialized:
                lhs, rhs = ser_self[field_name], ser_other[field_name]
            else:
                lhs, rhs = self_value, other_value

            if not array_safe_eq(lhs, rhs):
                diff_result[field_name] = {"self": self_value, "other": other_value}

        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance in place from a nested dict, useful for configuration from command line args

        # Parameters:
         - `nested_dict : dict[str, Any]`
            maps field names to new values. for a field currently holding a
            `SerializableDataclass`, the entry is passed on to that object's
            `update_from_nested_dict`; any other field is overwritten with the entry.
            fields without an entry are left alone
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name not in nested_dict:
                continue

            update = nested_dict[field_name]
            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(update)
            else:
                setattr(self, field_name, update)

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_text: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(json_text))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        json_text: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(json_text))

512 

513 

# memoized so the hints for a class are not recomputed on every lookup
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "`typing.get_type_hints(cls)`, memoized through `functools.lru_cache`"
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints

520 

521 

def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries `get_cls_type_hints_cached` first and, if that yields an empty dict,
    calls `typing.get_type_hints` directly

    # Parameters:
     - `cls : Type[T]`
        class to get the type hints of

    # Returns:
     - `dict[str, Any]`
        non-empty mapping of names to type hints

    # Raises:
     - `TypeError` : if getting the hints raised `TypeError`, `NameError` or `ValueError`,
        or if the hints are empty. the original exception is chained as the cause
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            # empty result from the cache: ask `typing` directly
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` raises `TypeError` itself when `cls` is not a dataclass;
        # guard it so that building the message cannot mask the error being reported
        fields_desc: str
        if dataclasses.is_dataclass(cls):
            fields_desc = f"{dataclasses.fields(cls) = }"  # type: ignore[arg-type]
        else:
            fields_desc = "dataclasses.fields(cls) = <not a dataclass>"

        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" {fields_desc}\n"
            + f" {e = }"
        ) from e

    return cls_type_hints

541 

542 

class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

547 

548 

class FieldError(ValueError):
    "common base class of the field-related errors"

553 

554 

class NotSerializableFieldException(FieldError):
    "raised for a field which is not a `SerializableField`"

559 

560 

class FieldSerializationError(FieldError):
    "raised when serializing a field fails"

565 

566 

class FieldLoadingError(FieldError):
    "raised when loading a field fails"

571 

572 

class FieldTypeMismatchError(FieldError, TypeError):
    "raised when the value of a field does not match its type hint"

577 

578 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
     - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
     - `eq : bool`
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
     - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
     - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, fields are not validated when loading
        **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
     - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
        the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace regular fields (or missing field specifiers) with `SerializableField`
        # NOTE(review): a plain class-level default (e.g. `a: int = 5`) is neither a
        # `SerializableField` nor a `dataclasses.Field`, so it appears to be replaced by a
        # default-less `serializable_field()` here -- confirm this is intended
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy,
        # hence the explicit comparison against `True`
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # fields not marked for serialization are left out
                if not field.serialize:
                    continue

                # bound before the `try` so the error message below can always format it,
                # even when reading the attribute itself is what failed
                value: Any = "<value could not be read>"
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    # NOTE(review): after the branch above `value` may already be a dict, in which
                    # case a field-level `serialization_fn` is applied on top of it -- confirm intended
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that (it receives the whole `data` mapping)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, report them via `on_typecheck_mismatch`
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # the documented opt-out name is `validate_fields_types`; the misspelled
        # `validate_field_type` that used to be checked here is still honored
        if not (
            {"validate_fields_types", "validate_field_type"} & _methods_no_override
        ):
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)