Coverage for muutils\tensor_utils.py: 86%

133 statements  


1"""utilities for working with tensors and arrays. 

2 

3notably: 

4 

5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 

6- `DTYPE_MAP` mapping string representations of types to their type 

7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 

8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 

9 

10""" 

from __future__ import annotations

import json
import typing

import jaxtyping
import numpy as np
import torch

from muutils.errormode import ErrorMode
from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring


TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

if np.version.version < "2.0.0":
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int
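
# usage sketch (illustrative): look up the jaxtyping annotation class for a
# python, numpy, or torch scalar type
#   TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float  # True
#   TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int   # True
#   TYPE_TO_JAX_DTYPE[bool] is jaxtyping.Bool          # True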


# TODO: add proper type annotations to this signature
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

101 

102 def __init_subclass__(cls, *args, **kwargs): 

103 raise TypeError(f"Cannot subclass {cls.__name__}") 

104 

105 @classmethod 

106 def param_info(cls, params) -> str: 

107 """useful for error printing""" 

108 return "\n".join( 

109 f"{k} = {v}" 

110 for k, v in { 

111 "cls.__name__": cls.__name__, 

112 "cls.__doc__": cls.__doc__, 

113 "params": params, 

114 "type(params)": type(params), 

115 }.items() 

116 ) 

117 

118 @typing._tp_cache # type: ignore 

119 def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: ignore 

120 # MyTensor["dim1 dim2"] 

121 if isinstance(params, str): 

122 return default_jax_dtype[array_type, params] 

123 

124 elif isinstance(params, tuple): 

125 if len(params) != 2: 

126 raise Exception( 

127 f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" 

128 ) 

129 

130 if isinstance(params[0], str): 

131 # MyTensor["dim1 dim2", int] 

132 return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 

133 

134 elif isinstance(params[0], tuple): 

135 legacy_mode.process( 

136 f"legacy type annotation was used:\n{cls.param_info(params) = }", 

137 except_cls=Exception, 

138 ) 

139 # MyTensor[("dim1", "dim2"), int] 

140 shape_anot: list[str] = list() 

141 for x in params[0]: 

142 if isinstance(x, str): 

143 shape_anot.append(x) 

144 elif isinstance(x, int): 

145 shape_anot.append(str(x)) 

146 elif isinstance(x, tuple): 

147 shape_anot.append("".join(str(y) for y in x)) 

148 else: 

149 raise Exception( 

150 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 

151 ) 

152 

153 return TYPE_TO_JAX_DTYPE[params[1]][ 

154 array_type, " ".join(shape_anot) 

155 ] 

156 else: 

157 raise Exception( 

158 f"unexpected type for params:\n{cls.param_info(params)}" 

159 ) 

160 

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray


if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy, and there is no way to run them only
    # when mypy is not checking; hence the additional ignores later on
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
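
# usage sketch (illustrative): the factory's products are meant to be used as
# type annotations, e.g.
#   x: ATensor["batch seq", torch.float32]  # jaxtyping.Float[torch.Tensor, "batch seq"]
#   y: NDArray["n 3", np.int64]             # jaxtyping.Int[np.ndarray, "n 3"]
#   z: ATensor["dim1 dim2"]                 # default dtype, jaxtyping.Float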


def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
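
# usage sketch (illustrative):
#   numpy_to_torch_dtype(np.float32)     # -> torch.float32
#   numpy_to_torch_dtype(torch.float64)  # returned unchanged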


DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

if np.version.version < "2.0.0":
    DTYPE_LIST.extend([np.float_, np.int_])

DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# `str(bool)` is "<class 'bool'>" and numpy only registers the name "bool_",
# so the plain "bool" key has to be added to both maps by hand
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
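
# usage sketch (illustrative): both maps are keyed by string representations,
# so a dtype can be looked up from a config value or CLI argument
#   DTYPE_MAP["float32"]        # -> np.float32
#   DTYPE_MAP["torch.float32"]  # -> torch.float32
#   TORCH_DTYPE_MAP["float32"]  # -> torch.float32
#   TORCH_DTYPE_MAP["bool"]     # -> torch.bool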


TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
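
# usage sketch (illustrative; `model` is a hypothetical torch.nn.Module):
#   optimizer = TORCH_OPTIMIZERS_MAP["AdamW"](model.parameters(), lr=1e-3)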


def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
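
# usage sketch (illustrative):
#   pad_tensor(torch.tensor([1.0, 2.0, 3.0]), 5)             # tensor([0., 0., 1., 2., 3.])
#   pad_tensor(torch.tensor([1.0, 2.0, 3.0]), 5, rpad=True)  # tensor([1., 2., 3., 0., 0.])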


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)


def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)


def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
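
# usage sketch (illustrative): dotted keys are unflattened into nested dicts
#   get_dict_shapes({"block.weight": torch.zeros(4, 3), "block.bias": torch.zeros(4)})
#   # -> {"block": {"weight": (4, 3), "bias": (4,)}}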


class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass


def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
      (defaults to `1e-5`)
    - `atol : float`
      (defaults to `1e-8`)
    - `verbose : bool`
      (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
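
# usage sketch (illustrative; `model_a` and `model_b` are hypothetical modules):
#   compare_state_dicts(model_a.state_dict(), model_b.state_dict())  # silent if they match
#   try:
#       compare_state_dicts({"w": torch.zeros(2)}, {"w": torch.zeros(3)})
#   except StateDictShapeError as err:
#       print(err)  # reports which keys differ in shape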