Coverage for muutils\tensor_utils.py: 86%
128 statements
1"""utilities for working with tensors and arrays.
3notably:
5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
6- `DTYPE_MAP` mapping string representations of types to their type
7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match
10"""
from __future__ import annotations

import json
import typing

import jaxtyping
import numpy as np
import torch

from muutils.errormode import ErrorMode
from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring
TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float_: jaxtyping.Float,
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int_: jaxtyping.Int,
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"
# TODO: add proper type annotations to this signature
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")
        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name
    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"
    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
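# usage sketch (illustrative only, kept as a comment so nothing runs at import time):
# the factory-built aliases are meant to be used as annotations; the functions below
# are hypothetical and not part of this module
#   def scale(x: ATensor["batch dim", np.float32]) -> ATensor["batch dim", np.float32]:
#       return x * 2.0
#
#   def normalize(v: NDArray["n"]) -> NDArray["n"]:
#       return v / np.linalg.norm(v)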
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
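# example (illustrative, commented out): the conversion round-trips through a
# zero-dimensional numpy array, so any dtype accepted by `np.array` should work
#   >>> numpy_to_torch_dtype(np.float32)
#   torch.float32
#   >>> numpy_to_torch_dtype(np.dtype("int64"))
#   torch.int64
#   >>> numpy_to_torch_dtype(torch.bool)  # torch dtypes pass through unchanged
#   torch.bool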
DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float_,
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.int_,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"
DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# neither `str(bool)` (which is "<class 'bool'>") nor `np.bool_.__name__` (which is
# "bool_") yields a plain "bool" key above, so add it explicitly
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
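# usage sketch (illustrative only; `model` and the learning rate are hypothetical and
# not part of this module): resolve an optimizer class from a config string
#   optimizer_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
#   optimizer = optimizer_cls(model.parameters(), lr=1e-3)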
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
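# example (illustrative, commented out): left-padding prepends the fill values,
# right-padding appends them
#   >>> lpad_tensor(torch.ones(3), 5)
#   tensor([0., 0., 1., 1., 1.])
#   >>> rpad_tensor(torch.ones(3), 5)
#   tensor([1., 1., 1., 0., 0.])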
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
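# example (illustrative, commented out): dotted keys are expanded into a nested dict of
# shapes, so a hypothetical state dict like
#   {"layer.weight": torch.zeros(4, 3), "layer.bias": torch.zeros(4)}
# comes back from `get_dict_shapes` as roughly
#   {"layer": {"weight": (4, 3), "bias": (4,)}}
# and `string_dict_shapes` renders the same structure as indented JSON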
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
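# usage sketch (illustrative only; `model_a` and `model_b` are hypothetical modules,
# not part of this file): `compare_state_dicts` returns None when everything matches
# and raises the most specific `StateDictCompareError` subclass otherwise
#   try:
#       compare_state_dicts(model_a.state_dict(), model_b.state_dict())
#       print("state dicts match")
#   except StateDictCompareError as e:
#       print(f"state dicts differ:\n{e}")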