Coverage for tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py: 87%

483 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-14 01:33 -0700

1from __future__ import annotations 

2 

3from copy import deepcopy 

4import typing 

5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union 

6 

7import pytest 

8 

9from muutils.errormode import ErrorMode 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from muutils.json_serialize.serializable_dataclass import ( 

17 FieldIsNotInitOrSerializeWarning, 

18 FieldTypeMismatchError, 

19) 

20from muutils.json_serialize.util import _FORMAT_KEY 

21 

22# pylint: disable=missing-class-docstring, unused-variable 

23 

24 

@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Minimal dataclass whose fields all rely on automatic (default) serialization."""

    a: str
    b: int
    c: typing.List[int]

30 

31 

def test_basic_auto_fields():
    """Serialize `BasicAutofields` and check equality / diff behaviour."""
    data = dict(a="hello", b=42, c=[1, 2, 3])
    instance = BasicAutofields(**data)

    data_with_format = data.copy()
    data_with_format[_FORMAT_KEY] = "BasicAutofields(SerializableDataclass)"
    assert instance.serialize() == data_with_format

    # `instance == instance` alone is trivially true; also compare against an
    # independently built instance (with its own list object) so field-wise
    # equality and `diff` are actually exercised
    twin = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    assert instance == instance
    assert instance == twin
    assert instance.diff(instance) == {}
    assert instance.diff(twin) == {}

40 

41 

def test_basic_diff():
    """`diff` lists exactly the differing fields as {"self": ..., "other": ...}."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    changed_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    changed_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    changed_bc = BasicAutofields(a="hello", b=-1, c=[42])

    cases = [
        (base, changed_a, {"a": {"self": "hello", "other": "goodbye"}}),
        (base, changed_b, {"b": {"self": 42, "other": -1}}),
        (
            base,
            changed_bc,
            {
                "b": {"self": 42, "other": -1},
                "c": {"self": [1, 2, 3], "other": [42]},
            },
        ),
        (base, base, {}),
        (
            changed_a,
            changed_b,
            {
                "a": {"self": "goodbye", "other": "hello"},
                "b": {"self": 42, "other": -1},
            },
        ),
    ]
    for lhs, rhs, expected in cases:
        assert lhs.diff(rhs) == expected

59 

60 

@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """One required field, one plain default, and one default-factory field."""

    d: str
    e: int = 42
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821

66 

67 

@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises the per-field options accepted by `serializable_field`."""

    a: str = serializable_field()
    b: str = serializable_field()
    # excluded from __init__, serialization, repr and comparison
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # upper-cased on serialize; `loading_fn` receives the whole serialized dict,
    # reads its "d" entry and lower-cases it back
    d: str = serializable_field(
        serialization_fn=lambda x: x.upper(), loading_fn=lambda x: x["d"].lower()
    )

76 

77 

@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Dataclass whose `full_name` property is written into the serialized dict."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name separated by a single space."""
        return f"{self.first_name} {self.last_name}"

86 

87 

class Child(FieldOptions, WithProperty):
    """Plain subclass combining two serializable dataclasses.

    It is not decorated with `@serializable_dataclass`, and no test in this file
    uses it. Presumably it only checks that this multiple inheritance is accepted
    at import time.
    """

90 

91 

@pytest.fixture
def simple_fields_instance():
    """A `SimpleFields` with every field given explicitly."""
    return SimpleFields(d="hello", e=42, f=[1, 2, 3])


@pytest.fixture
def field_options_instance():
    """A `FieldOptions` built from its init fields only (`c` has init=False)."""
    return FieldOptions(a="hello", b="world", d="case")


@pytest.fixture
def with_property_instance():
    """A `WithProperty` whose `full_name` evaluates to "John Doe"."""
    return WithProperty(first_name="John", last_name="Doe")

105 

106 

def test_simple_fields_serialization(simple_fields_instance):
    """Defaulted and default-factory fields serialize alongside the required one."""
    expected = dict(d="hello", e=42, f=[1, 2, 3])
    expected[_FORMAT_KEY] = "SimpleFields(SerializableDataclass)"
    assert simple_fields_instance.serialize() == expected

115 

116 

def test_simple_fields_loading(simple_fields_instance):
    """`load` inverts `serialize`, with an empty diff in both directions."""
    reloaded = SimpleFields.load(simple_fields_instance.serialize())

    assert reloaded == simple_fields_instance
    pairs = ((reloaded, simple_fields_instance), (simple_fields_instance, reloaded))
    for left, right in pairs:
        assert left.diff(right) == {}

125 

126 

def test_field_options_serialization(field_options_instance):
    """Only `a`, `b` and the upper-cased `d` are emitted; `c` has serialize=False."""
    expected = {_FORMAT_KEY: "FieldOptions(SerializableDataclass)"}
    expected.update(a="hello", b="world", d="CASE")
    assert field_options_instance.serialize() == expected

135 

136 

def test_field_options_loading(field_options_instance):
    """Round-trip `FieldOptions`, expecting a `FieldIsNotInitOrSerializeWarning`.

    The warning is not ignored: `pytest.warns` makes the test fail if loading
    does not emit it. The trigger is presumably field `c`, which is declared with
    ``init=False, serialize=False``. Because `c` also has ``compare=False``,
    equality ignores it.
    """
    serialized = field_options_instance.serialize()
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        loaded = FieldOptions.load(serialized)
    assert loaded == field_options_instance

143 

144 

def test_with_property_serialization(with_property_instance):
    """Properties listed in `properties_to_serialize` appear in the output."""
    names = {"first_name": "John", "last_name": "Doe"}
    expected = {
        **names,
        "full_name": "John Doe",
        _FORMAT_KEY: "WithProperty(SerializableDataclass)",
    }
    assert with_property_instance.serialize() == expected

153 

154 

def test_with_property_loading(with_property_instance):
    """Data that also contains the serialized `full_name` loads back to an equal instance."""
    reloaded = WithProperty.load(with_property_instance.serialize())
    assert reloaded == with_property_instance

159 

160 

@serializable_dataclass
class Address(SerializableDataclass):
    """Simple all-string dataclass, used as a nested field of `Person`."""

    street: str
    city: str
    zip_code: str

166 

167 

@serializable_dataclass
class Person(SerializableDataclass):
    """Dataclass that nests another serializable dataclass (`Address`)."""

    name: str
    age: int
    address: Address

173 

174 

@pytest.fixture
def address_instance():
    """The `Address` nested inside `person_instance`."""
    return Address(street="123 Main St", city="New York", zip_code="10001")


@pytest.fixture
def person_instance(address_instance):
    """A `Person` wrapping the `address_instance` fixture."""
    return Person(name="John Doe", age=30, address=address_instance)

183 

184 

def test_nested_serialization(person_instance):
    """A nested dataclass is serialized to its own dict, with its own format key."""
    address_ser = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        _FORMAT_KEY: "Address(SerializableDataclass)",
    }
    expected_ser = {
        "name": "John Doe",
        "age": 30,
        "address": address_ser,
        _FORMAT_KEY: "Person(SerializableDataclass)",
    }
    assert person_instance.serialize() == expected_ser

199 

200 

def test_nested_loading(person_instance):
    """Loading restores both the outer `Person` and the nested `Address`."""
    reloaded = Person.load(person_instance.serialize())
    assert reloaded == person_instance
    # the nested value must also be equal on its own
    assert reloaded.address == person_instance.address

206 

207 

def test_with_printing():
    """Round-trip a class mixing a serialized property, a custom field, and a default factory.

    Both the serialized dict and the reloaded instance are printed for manual
    inspection. The assertions pin the custom `age` transform (+1 on serialize,
    -1 on load) and the lossless round trip; previously the test asserted nothing.
    """

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    print(serialized_data)
    # serialization_fn adds 1; the property is emitted as well
    assert serialized_data["age"] == 31
    assert serialized_data["full_name"] == "John Doe"

    loaded_instance = MyClass.load(serialized_data)
    print(loaded_instance)
    # loading_fn subtracts 1 again, so the round trip is lossless
    assert loaded_instance.age == 30
    assert loaded_instance == my_instance

228 

229 

def test_simple_class_serialization():
    """A locally defined two-field dataclass serializes and loads back."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    as_dict = original.serialize()
    expected = {"a": 42, "b": "hello"}
    expected[_FORMAT_KEY] = "SimpleClass(SerializableDataclass)"
    assert as_dict == expected

    assert SimpleClass.load(as_dict) == original

246 

247 

def test_error_when_init_and_not_serialize():
    """A field with ``init=True, serialize=False`` is rejected while the class is defined."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)

254 

255 

def test_person_serialization():
    """Defaults, a default factory, and a serialized property, round-tripped together."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    person = FullPerson(name="John", items=["apple", "banana"])  # age stays at -1
    serialized = person.serialize()

    expected_ser = dict(
        name="John",
        age=-1,
        items=["apple", "banana"],
        full_name="John Doe",
    )
    expected_ser[_FORMAT_KEY] = "FullPerson(SerializableDataclass)"
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"

    assert FullPerson.load(serialized) == person

281 

282 

def test_custom_serialization():
    """A field with paired `serialization_fn` / `loading_fn` round-trips its value."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        # doubled on serialize, halved (integer division) on load
        data: Any = serializable_field(
            serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["data"] // 2
        )

    custom = CustomSerialization(data=5)
    serialized = custom.serialize()
    expected = {"data": 10}
    expected[_FORMAT_KEY] = "CustomSerialization(SerializableDataclass)"
    assert serialized == expected

    assert CustomSerialization.load(serialized) == custom

299 

300 

@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Holds a list of `BasicAutofields`, (de)serialized element by element."""

    val_int: int
    val_str: str
    # each element serializes itself; on load, `loading_fn` receives the whole
    # serialized dict and rebuilds the list from its "val_list" entry
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda x: [y.serialize() for y in x],
        loading_fn=lambda x: [BasicAutofields.load(y) for y in x["val_list"]],
    )

310 

311 

def test_nested_with_container():
    """A list field holding serializable dataclasses round-trips."""
    specs = [("a", 1, [1, 2, 3]), ("b", 2, [4, 5, 6])]

    # expected output built first, from list copies, so it is independent of the instance
    expected_ser = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {
                "a": a,
                "b": b,
                "c": list(c),
                _FORMAT_KEY: "BasicAutofields(SerializableDataclass)",
            }
            for a, b, c in specs
        ],
        _FORMAT_KEY: "Nested_with_Container(SerializableDataclass)",
    }

    instance = Nested_with_Container(
        val_int=42,
        val_str="hello",
        val_list=[BasicAutofields(a=a, b=b, c=c) for a, b, c in specs],
    )

    serialized = instance.serialize()
    assert serialized == expected_ser
    assert Nested_with_Container.load(serialized) == instance

348 

349 

class Custom_class_with_serialization:
    """Custom class that does not inherit from `SerializableDataclass` but exposes
    the same `serialize` / `load` pair, so it can be nested inside one."""

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return a plain dict with keys ``"a"`` and ``"b"``."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from a dict shaped like `serialize`'s output."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        # For foreign types, defer with NotImplemented rather than raising
        # AttributeError on `other.a`; Python then falls back to identity,
        # so e.g. `obj == None` is simply False.
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)

366 

367 

@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Dataclass with a non-dataclass field that supplies its own serialize/load."""

    value: float
    data1: Custom_class_with_serialization

372 

373 

def test_nested_custom(recwarn):
    """Round-trip a dataclass whose field is a custom class with `serialize`/`load`.

    The `recwarn` fixture records the warnings this emits; their content is not
    checked.
    """
    payload = Custom_class_with_serialization(1, "hello")
    instance = nested_custom(value=42.0, data1=payload)

    serialized = instance.serialize()
    assert serialized == {
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
        _FORMAT_KEY: "nested_custom(SerializableDataclass)",
    }
    assert nested_custom.load(serialized) == instance

387 

388 

def test_deserialize_fn():
    """`deserialize_fn` converts the stored field value back (here str -> int)."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        # kept as an int in memory, written out as its string form
        data: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )

    original = DeserializeFn(data=5)
    payload = original.serialize()
    expected = {"data": "5"}
    expected[_FORMAT_KEY] = "DeserializeFn(SerializableDataclass)"
    assert payload == expected

    restored = DeserializeFn.load(payload)
    assert restored == original
    assert restored.data == 5

407 

408 

@serializable_dataclass
class DictContainer(SerializableDataclass):
    """Dataclass with one required and two default-factory dict fields."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)

416 

417 

def test_dict_serialization():
    """Flat and nested dict fields serialize to equal plain dicts."""
    container = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )

    expected = dict(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )
    expected[_FORMAT_KEY] = "DictContainer(SerializableDataclass)"

    assert container.serialize() == expected

435 

436 

def test_dict_loading():
    """A hand-written serialized dict loads into the matching dict fields."""
    loaded = DictContainer.load(
        {
            _FORMAT_KEY: "DictContainer(SerializableDataclass)",
            "simple_dict": {"a": 1, "b": 2},
            "nested_dict": {"x": {"y": 3, "z": 4}},
            "optional_dict": {"hello": "world"},
        }
    )

    # expected values are separate literals, so they cannot alias the input
    expected_fields = {
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    for field_name, expected in expected_fields.items():
        assert getattr(loaded, field_name) == expected

450 

451 

def test_dict_equality():
    """Equality compares dict fields by value."""

    def make(simple_dict):
        # fresh literal dicts on every call
        return DictContainer(
            simple_dict=simple_dict,
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    first = make({"a": 1, "b": 2})
    second = make({"a": 1, "b": 2})
    different = make({"a": 1, "b": 3})  # only simple_dict["b"] differs

    assert first == second
    assert first != different
    assert second != different

475 

476 

def test_dict_diff():
    """`diff` reports the whole dict value for every dict field that differs."""

    def make(**overrides):
        # baseline field values (fresh dicts per call), optionally overridden
        fields = {
            "simple_dict": {"a": 1, "b": 2},
            "nested_dict": {"x": {"y": 3, "z": 4}},
            "optional_dict": {"hello": "world"},
        }
        fields.update(overrides)
        return DictContainer(**fields)

    baseline = make()
    cases = [
        (
            make(simple_dict={"a": 1, "b": 3}),
            {"simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}},
        ),
        (
            make(nested_dict={"x": {"y": 3, "z": 5}}),
            {
                "nested_dict": {
                    "self": {"x": {"y": 3, "z": 4}},
                    "other": {"x": {"y": 3, "z": 5}},
                }
            },
        ),
        (
            make(optional_dict={"hello": "python"}),
            {"optional_dict": {"self": {"hello": "world"}, "other": {"hello": "python"}}},
        ),
        # identical instance -> no differences
        (baseline, {}),
    ]
    for other, expected in cases:
        assert baseline.diff(other) == expected

529 

530 

@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Dicts with mixed values, list values, and three levels of nesting."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]

538 

539 

def test_complex_dict_serialization():
    """Mixed, list-valued and deeply nested dicts survive a round trip."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}

552 

553 

def test_empty_dicts():
    """All-empty dict fields round-trip and compare equal."""

    def empty_container():
        return DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    original = empty_container()
    restored = DictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}

    # a separately built empty instance is equal too
    assert original == empty_container()

566 

567 

# Dict value-type validation; type mismatches are configured to raise
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Dict fields with concrete value types, checked with `ErrorMode.EXCEPT`."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]

576 

577 

# TODO: find out why dict value types are not validated, then un-skip
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """Value types inside dict fields should be validated (currently skipped)."""
    valid = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert valid.validate_fields_types()

    invalid_field_sets = [
        # int_dict holds a str value
        dict(int_dict={"a": "not an int"}, str_dict={"x": "hello"}, float_dict={"m": 1.0}),
        # str_dict holds an int value
        dict(int_dict={"a": 1}, str_dict={"x": 123}, float_dict={"m": 1.0}),
    ]
    for fields in invalid_field_sets:
        with pytest.raises(FieldTypeMismatchError):
            StrictDictContainer(**fields)

605 

606 

# Dicts whose values (or the dict itself) may be None or one of several types
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Dict fields with Optional / Union value types, plus a nullable dict."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    nullable_dict: Optional[Dict[str, int]] = None

615 

616 

def test_optional_dict_values():
    """Optional and union dict values, and a None dict, survive a round trip."""
    candidates = [
        OptionalDictContainer(
            optional_values={"a": 1, "b": None, "c": 3},
            union_values={"x": 1, "y": "string", "z": 42},
            nullable_dict={"m": 1, "n": 2},
        ),
        # all-None values, all-string union values, and a None dict
        OptionalDictContainer(
            optional_values={"a": None, "b": None},
            union_values={"x": "all strings", "y": "here"},
            nullable_dict=None,
        ),
    ]
    for candidate in candidates:
        assert OptionalDictContainer.load(candidate.serialize()) == candidate

639 

640 

# In-place mutation of dict fields
def test_dict_mutation():
    """A deepcopy is unaffected by in-place edits to the original's dict fields."""
    original = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )
    snapshot = deepcopy(original)

    # mutate every dict field of the original (one of them one level down)
    original.simple_dict["c"] = 3
    original.nested_dict["x"]["z"] = 4
    original.optional_dict["new"] = "value"

    # the copy keeps the pre-mutation values
    assert snapshot.simple_dict == {"a": 1, "b": 2}
    assert snapshot.nested_dict == {"x": {"y": 3}}
    assert snapshot.optional_dict == {"hello": "world"}

    # and every mutated field appears in the diff
    changed = snapshot.diff(original)
    for field_name in ("simple_dict", "nested_dict", "optional_dict"):
        assert field_name in changed

667 

668 

# Dicts with non-string keys
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Int-keyed dict, converted to and from string keys around serialization."""

    # keys become str on serialize; `loading_fn` receives the whole serialized
    # dict and converts the keys of its "int_keys" entry back to int
    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda x: {str(k): v for k, v in x.items()},
        loading_fn=lambda x: {int(k): v for k, v in x["int_keys"].items()},
    )

678 

679 

def test_non_string_dict_keys():
    """Int keys are stringified when serialized and restored as ints when loaded."""
    instance = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    serialized = instance.serialize()
    # serialized form uses string keys ...
    assert all(isinstance(key, str) for key in serialized["int_keys"])

    loaded = IntKeyDictContainer.load(serialized)
    # ... which are turned back into ints on load
    assert all(isinstance(key, int) for key in loaded.int_keys)
    assert loaded == instance

692 

693 

@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Single `Dict[str, Any]` field, used to hold arbitrarily deep nested data."""

    data: Dict[str, Any]

699 

700 

def test_recursive_dict_structure():
    """Deeply nested dicts, including a dict inside a list, round-trip unchanged."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    instance = RecursiveDictContainer(data=deep_dict)
    restored = RecursiveDictContainer.load(instance.serialize())

    assert restored == instance
    assert restored.data == deep_dict

715 

716 

# Defined at module level rather than inside `test_dict_with_custom_objects`.
# Under `from __future__ import annotations`, field annotations such as
# `Dict[str, CustomSerializable]` are strings, and presumably the type validator
# can only resolve names it finds in module globals.
class CustomSerializable:
    """Non-dataclass object implementing the `serialize` / `load` protocol."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Wrap the value as ``{"value": ...}``."""
        return dict(value=self.value)

    @classmethod
    def load(cls, data):
        """Inverse of `serialize`."""
        stored = data["value"]
        return cls(stored)

    def __eq__(self, other):
        if not isinstance(other, CustomSerializable):
            return False
        return self.value == other.value

731 

732 

def test_dict_with_custom_objects():
    """A dict whose values implement serialize/load round-trips."""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        objects: Dict[str, CustomSerializable]

    members = {"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    instance = CustomObjectDict(objects=members)

    assert CustomObjectDict.load(instance.serialize()) == instance

747 

748 

def test_empty_optional_dicts():
    """None and an empty dict stay distinguishable in an optional dict field."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    with_none = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    with_empty = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    serialized = [inst.serialize() for inst in (with_none, with_empty)]
    loaded_none, loaded_empty = [OptionalDictFields.load(s) for s in serialized]

    assert loaded_none.optional_dict is None
    assert loaded_empty.optional_dict == {}
    # the default factory yields {} in both cases
    assert loaded_none.default_empty == {}
    assert loaded_empty.default_empty == {}

774 

775 

# Inheritance hierarchy: BaseClass -> ChildClass -> GrandchildClass
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT, on_typecheck_mismatch=ErrorMode.EXCEPT
)
class BaseClass(SerializableDataclass):
    """Root of the inheritance tests; type-check problems are configured to raise."""

    base_field: str
    shared_field: int = 0

785 

786 

@serializable_dataclass
class ChildClass(BaseClass):
    """Adds `child_field` and re-declares `shared_field` with a different default."""

    child_field: float
    shared_field: int = 1  # base class default is 0

793 

794 

@serializable_dataclass
class GrandchildClass(ChildClass):
    """Third level of the hierarchy, adding one boolean field."""

    grandchild_field: bool

800 

801 

def test_inheritance():
    """Fields from every level of the hierarchy serialize and load back."""
    instance = GrandchildClass(
        base_field="base", shared_field=42, child_field=3.14, grandchild_field=True
    )

    serialized = instance.serialize()
    assert serialized["base_field"] == "base"
    assert serialized["shared_field"] == 42
    assert serialized["child_field"] == 3.14
    assert serialized["grandchild_field"] is True

    loaded = GrandchildClass.load(serialized)
    assert type(loaded) is GrandchildClass
    assert loaded == instance

    # The base class loads on its own, produces exactly a BaseClass, and keeps
    # the given values (previously only the isinstance checks were made).
    base_loaded = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(base_loaded, BaseClass)
    assert not isinstance(base_loaded, ChildClass)
    assert base_loaded.base_field == "test"
    assert base_loaded.shared_field == 1

821 

822 

@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """A generic dataclass should round-trip for each concrete type argument (skipped)."""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Test generic type parameters"""

        value: T
        values: List[T]

    for type_arg, value, values in (
        (int, 42, [1, 2, 3]),
        (str, "hello", ["a", "b", "c"]),
    ):
        specialized = GenericContainer[type_arg]
        container = specialized(value=value, values=values)
        loaded = specialized.load(container.serialize())
        assert loaded == container

849 

850 

# Value type with no serialization support of its own; handled via field-level fns
class CustomObject:
    """Opaque holder of a single `value`, compared by that value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        same_type = isinstance(other, CustomObject)
        return same_type and self.value == other.value

858 

859 

@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Fields with explicit `serialization_fn` / `loading_fn` pairs."""

    # serialized as the wrapped value; `loading_fn` re-wraps the "custom_obj" entry
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: CustomObject(x["custom_obj"]),
    )
    # doubled on serialize, halved (integer division) on load
    transform_field: int = serializable_field(
        serialization_fn=lambda x: x * 2, loading_fn=lambda x: x["transform_field"] // 2
    )

871 

872 

def test_custom_serialization_2():
    """Custom serialize/load functions transform values and invert each other."""
    instance = CustomSerializationContainer(
        custom_obj=CustomObject(42), transform_field=10
    )

    serialized = instance.serialize()
    assert (serialized["custom_obj"], serialized["transform_field"]) == (42, 20)

    restored = CustomSerializationContainer.load(serialized)
    assert restored == instance
    assert restored.transform_field == 10

886 

887 

888# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`") 

889# def test_value_validation(): 

890# """Test field validation""" 

891# @serializable_dataclass 

892# class ValidationContainer(SerializableDataclass): 

893# """Test validation and error handling""" 

894# positive_int: int = serializable_field( 

895# custom_value_check_fn=lambda x: x > 0 

896# ) 

897# email: str = serializable_field( 

898# custom_value_check_fn=lambda x: '@' in x 

899# ) 

900 

901# # Valid case 

902# valid = ValidationContainer(positive_int=42, email="test@example.com") 

903# assert valid.validate_fields_types() 

904 

905# # what will this do? 

906# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com") 

907# assert maybe_valid.validate_fields_types() 

908 

909# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"]) 

910# assert maybe_valid_2.validate_fields_types() 

911 

912# # Invalid positive_int 

913# with pytest.raises(ValueError): 

914# ValidationContainer(positive_int=-1, email="test@example.com") 

915 

916# # Invalid email 

917# with pytest.raises(ValueError): 

918# ValidationContainer(positive_int=42, email="invalid") 

919 

920 

def test_init_true_serialize_false():
    """A class with an ``init=True, serialize=False`` field fails to be defined."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Test field metadata and options"""

            # presumably this field is what triggers the ValueError, as in
            # test_error_when_init_and_not_serialize
            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                object.__setattr__(self, "computed", self.readonly * 2.0)

934 

935 

# Serializing several properties at once
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Two computed properties are written into the serialized output."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        """First and last name separated by a space."""
        return f"{self.first_name} {self.last_name}"

    @property
    def age_in_months(self) -> int:
        """Age converted from years to months."""
        return self.age_years * 12

952 

953 

def test_property_serialization():
    """Both properties appear in the output, and the data still loads back."""
    instance = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    serialized = instance.serialize()
    expected_properties = {"full_name": "John Doe", "age_in_months": 360}
    for prop_name, expected in expected_properties.items():
        assert serialized[prop_name] == expected

    assert PropertyContainer.load(serialized) == instance

964 

965 

# TODO: would be nice to support, but not a big issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """A dataclass holding an instance of itself, plus None/empty/union defaults (skipped)."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Test edge cases and corner cases"""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    # self-referential type, acyclic value
    inner = EdgeCaseContainer()
    instance = EdgeCaseContainer(recursive_ref=inner)
    loaded = EdgeCaseContainer.load(instance.serialize())
    assert loaded == instance

    # defaults for empty / None fields
    defaults = EdgeCaseContainer()
    assert defaults.empty_list == []
    assert defaults.optional_value is None
    assert defaults.union_field is None

    # the union field round-trips with either member type
    for union_value in ("string", 42):
        instance.union_field = union_value
        loaded = EdgeCaseContainer.load(instance.serialize())
        assert loaded.union_field == union_value

1004 

1005 

# Error handling for malformed data
def test_error_handling():
    """Missing required fields raise; mistyped fields are detected or raise on load."""
    # an empty dict lacks the required `base_field`
    with pytest.raises(TypeError):
        BaseClass.load({})

    # Both fields have the wrong type. Construction is expected not to raise,
    # but `validate_fields_types` reports the mismatch.
    mistyped = BaseClass(base_field=42, shared_field="invalid")
    assert not mistyped.validate_fields_types()

    # loading the same mistyped data raises
    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(
            {
                "base_field": 42,  # should be str
                "shared_field": "invalid",  # should be int
            }
        )

    # Not yet enforced: a mismatched format key
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         _FORMAT_KEY: "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })

1031 

1032 

# Cyclic references
# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
# NOTE(review): the body below expects serialization of a cycle to SUCCEED and
# round-trip, which contradicts the TODO above; reconcile before un-skipping.
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """Serializing a two-node cycle should not recurse forever (currently skipped)."""

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # link two nodes into a cycle: one -> two -> one
    first, second = Node("one"), Node("two")
    first.next = second
    second.next = first

    loaded = Node.load(first.serialize())
    assert loaded.value == "one"
    assert loaded.next.value == "two"