Coverage for muutils\json_serialize\json_serialize.py: 28%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-05 19:24 -0700

1"""provides the basic framework for json serialization of objects 

2 

3notably: 

4 

5- `SerializerHandler` defines how to serialize a specific type of object 

6- `JsonSerializer` handles configuration for which handlers to use 

7- `json_serialize` provides the default configuration if you don't care -- call it on any object! 

8 

9""" 

10 

11from __future__ import annotations 

12 

13import inspect 

14import warnings 

15from dataclasses import dataclass, is_dataclass 

16from pathlib import Path 

17from typing import Any, Callable, Iterable, Mapping, Set, Union 

18 

19from muutils.errormode import ErrorMode 

20 

21try: 

22 from muutils.json_serialize.array import ArrayMode, serialize_array 

23except ImportError as e: 

24 ArrayMode = str # type: ignore[misc] 

25 serialize_array = lambda *args, **kwargs: None # noqa: E731 

26 warnings.warn( 

27 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", 

28 ImportWarning, 

29 ) 

30 

31from muutils.json_serialize.util import ( 

32 Hashableitem, 

33 JSONitem, 

34 MonoTuple, 

35 SerializationException, 

36 _recursive_hashify, 

37 isinstance_namedtuple, 

38 safe_getsource, 

39 string_as_lines, 

40 try_catch, 

41) 

42 

43# pylint: disable=protected-access 

44 

# attribute names the fallback handler reads off an object (each value is
# stringified, with `None` used when the attribute is missing)
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    "__dict__",
    "__annotations__",
)

# name -> function that the fallback handler applies to an object; every entry
# except `str` and `dir` is wrapped with `try_catch`
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = dict(
    str=str,
    dir=dir,
    type=try_catch(lambda x: str(type(x).__name__)),
    repr=try_catch(lambda x: repr(x)),
    code=try_catch(lambda x: inspect.getsource(x)),
    sourcefile=try_catch(lambda x: inspect.getsourcefile(x)),
)

# values of `str(type(obj))` whose objects are serialized directly as `str(obj)`
SERIALIZE_DIRECT_AS_STR: Set[str] = {
    "<class 'torch.device'>",
    "<class 'torch.dtype'>",
}

# location inside a nested object: a tuple of dict keys and/or list indices
ObjectPath = MonoTuple[Union[str, int]]

69 

70 

@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    # Parameters:
    - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
        takes a JsonSerializer, an object, and the current path; returns whether to use this handler
    - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
        takes a JsonSerializer, an object, and the current path; returns the serialized object
    - `uid : str`
        unique identifier for the handler
    - `desc : str`
        description of the handler (required)
    """

    # (self_config, object, path) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> dict:
        """serialize the handler info

        # Returns:
        - `dict` with the source code and docstring lines of both `check` and
          `serialize_func`, plus `uid`, `desc`, and the `source_pckg` /
          `__module__` attributes of `serialize_func` (`None` when absent)
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the serialize function
            "serialize_func": {
                "code": safe_getsource(self.serialize_func),
                "doc": string_as_lines(self.serialize_func.__doc__),
            },
            # get the uid, source_pckg, module, and desc
            "uid": str(self.uid),
            # `source_pckg` is only reported if it was attached to `serialize_func`
            "source_pckg": getattr(self.serialize_func, "source_pckg", None),
            "__module__": getattr(self.serialize_func, "__module__", None),
            "desc": str(self.desc),
        }

109 

110 

# minimal handler set: primitives, mappings, and lists/tuples. Handlers are
# tried in order, so primitives are matched before the container handlers.
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    # primitive values are returned unchanged
    SerializerHandler(
        uid="base types",
        desc="base types (bool, int, float, str, None)",
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, type(None))
        ),
        serialize_func=lambda self, obj, path: obj,
    ),
    # any Mapping becomes a dict with `str()` keys; values are serialized
    # recursively, and the original (non-stringified) key is appended to the path
    SerializerHandler(
        uid="dictionaries",
        desc="dictionaries",
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(key): self.json_serialize(value, (*path, key))
            for key, value in obj.items()
        },
    ),
    # lists and tuples both become lists; each element's index is appended to the path
    SerializerHandler(
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(item, (*path, idx)) for idx, item in enumerate(obj)
        ],
    ),
)

137 

138 

139def _serialize_override_serialize_func( 

140 self: "JsonSerializer", obj: Any, path: ObjectPath 

141) -> JSONitem: 

142 # obj_cls: type = type(obj) 

143 # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self): 

144 # obj_cls._register_self() 

145 

146 # get the serialized object 

147 return obj.serialize() 

148 

149 

# the handler chain used by `JsonSerializer` by default: `BASE_HANDLERS` first,
# then the handlers below, tried in order, ending with a catch-all fallback
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    SerializerHandler(
        # NOTE(review): namedtuples are `tuple` subclasses, so in this ordering
        # the `(list, tuple) -> list` handler from `BASE_HANDLERS` appears to
        # match them first and this handler is not reached. The order is kept
        # as-is to avoid changing the serialized format of namedtuples.
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        # forward `path` so nested errors / array serialization see the real location
        serialize_func=lambda self, obj, path: self.json_serialize(
            dict(obj._asdict()), path
        ),
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: is_dataclass(obj),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        # matched by type name so numpy does not need to be importable here
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach from the autograd graph and move to CPU before array serialization
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj))
        == "<class 'pandas.core.frame.DataFrame'>",
        serialize_func=lambda self, obj, path: dict(
            __format__="pandas.DataFrame",
            columns=obj.columns.tolist(),
            data=obj.to_dict(orient="records"),
            path=path,
        ),
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (set, list, tuple))
        or isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
        desc="sets, lists, tuples, and Iterables as lists",
    ),
    SerializerHandler(
        # catch-all: always matches, so it must stay last
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)

231 

232 

class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object: `ErrorMode.EXCEPT` raises,
        any other mode returns `repr(obj)` as a fallback (`ErrorMode.WARN` also emits a warning)
        (defaults to `ErrorMode.EXCEPT`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # all configuration is keyword-only; reject stray positional arguments
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        # converted with `ErrorMode.from_any`, so comparisons below use enum members
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers; they are tried in order, so `handlers_pre` take priority
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        """serialize `obj` using the first handler whose `check` accepts it

        If `write_only_format` is set and the handler returns a dict with a
        `"__format__"` key, that key is renamed to `"__write_format__"`.

        # Parameters:
        - `obj : Any` object to serialize
        - `path : ObjectPath` location of `obj` within the root object; passed
          to handlers and included in error messages (defaults to `tuple()`)

        # Returns:
        - `JSONitem` the serialized object, or `repr(obj)` if serialization
          failed and `error_mode` is not `ErrorMode.EXCEPT`

        # Raises:
        - `SerializationException` if serialization fails and `error_mode` is `ErrorMode.EXCEPT`
        """
        # last handler tried, reported in the error message; stays `None` when
        # `self.handlers` is empty so the except branch never hits an unbound name
        handler: SerializerHandler | None = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            # `self.error_mode` is an `ErrorMode` (see `__init__`), so compare
            # against members rather than relying on enum/string equality
            if self.error_mode == ErrorMode.EXCEPT:
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                handler_uid: str | None = None if handler is None else handler.uid
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == ErrorMode.WARN:
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        `obj` is first serialized with `json_serialize` (same `path`), then the
        result is passed to `_recursive_hashify` together with `force`.
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)

324 

325 

# serializer with the default configuration, used by the module-level `json_serialize`
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config

    # Parameters:
    - `obj : Any` object to serialize
    - `path : ObjectPath` location of `obj` within a larger object (defaults to `tuple()`)

    # Returns:
    - `JSONitem` result of `GLOBAL_JSON_SERIALIZER.json_serialize(obj, path)`
    """
    serializer: JsonSerializer = GLOBAL_JSON_SERIALIZER
    return serializer.json_serialize(obj, path)