Coverage for muutils\tensor_utils.py: 86%
133 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
"""utilities for working with tensors and arrays.

notably:

- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP` : a mapping from string representations of types to their type
- `TORCH_DTYPE_MAP` : a mapping from string representations of types to torch types
- `compare_state_dicts` : compares two state dicts and gives a detailed error message on whether it was keys, shapes, or values that didn't match

"""
12from __future__ import annotations
14import json
15import typing
17import jaxtyping
18import numpy as np
19import torch
21from muutils.errormode import ErrorMode
22from muutils.dictmagic import dotlist_to_nested_dict
24# pylint: disable=missing-class-docstring
TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    # NOTE(review): an unsigned type is mapped to `jaxtyping.Int`; if the installed
    # jaxtyping's `Int` only accepts signed integers this should be `jaxtyping.UInt` -- confirm
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

# `np.float_` / `np.int_` are only registered on numpy 1.x (`np.float_` does not exist in
# numpy 2). compare the major version as an integer: comparing version *strings* is
# lexicographic and therefore wrong in general (e.g. `"10.0.0" < "2.0.0"` is True)
if int(np.version.version.split(".")[0]) < 2:
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int
# TODO: add proper type annotations to this signature
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """create a subscriptable shorthand that builds `jaxtyping` annotations for `array_type`

    usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```

    # Parameters:
     - `name : str`
        name of the returned class (set as both `__name__` and `__qualname__`)
     - `array_type : type`
        array type handed to the `jaxtyping` dtype, e.g. `torch.Tensor` or `np.ndarray`
     - `default_jax_dtype`
        `jaxtyping` dtype used when only a shape string is given
        (defaults to `jaxtyping.Float`)
     - `legacy_mode : ErrorMode`
        converted with `ErrorMode.from_any`; its `.process` is called whenever the legacy
        `X[("dim1", "dim2"), dtype]` syntax is used
        (defaults to `ErrorMode.WARN`)

    # Returns:
     - `type`
        a class that cannot be instantiated or subclassed, where
        - `X["dim1 dim2"]` -> `default_jax_dtype[array_type, "dim1 dim2"]`
        - `X["dim1 dim2", dtype]` -> `TYPE_TO_JAX_DTYPE[dtype][array_type, "dim1 dim2"]`
        - `X[("dim1", 2), dtype]` (legacy) -> same as `X["dim1 2", dtype]`

    any other subscript raises `Exception` with a description of the parameters
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            # report the actual shorthand name (e.g. `ATensor`), not a hard-coded one
            raise TypeError(f"Type {name} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            # `cls` here is the would-be subclass; name the shorthand type being subclassed
            raise TypeError(f"Cannot subclass {name}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = []
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            # a nested tuple is glued into a single dimension, e.g. ("n", "+1") -> "n+1"
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]

                else:
                    # never fall through and silently return `None` as the annotation
                    raise Exception(
                        f"unexpected type for params, expected first part to be str or tuple:\n{cls.param_info(params)}"
                    )

            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name
    # `repr()` of a class uses `__qualname__`; without this it would show
    # `jaxtype_factory.<locals>._BaseArray`
    _BaseArray.__qualname__ = name

    if _BaseArray.__doc__ is None:
        # docstrings are stripped (e.g. `python -OO`): fall back to a template whose
        # replacement fields match the `.format()` call below
        _BaseArray.__doc__ = (
            "default_jax_dtype = {default_jax_dtype}\narray_type = {array_type}"
        )

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    # at runtime `NDArray` is built from `np.ndarray`, so its static stand-in
    # subclasses `np.ndarray` (not `torch.Tensor`)
    class NDArray(np.ndarray):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()
# runtime shorthands: `ATensor[...]` builds `jaxtyping` annotations over `torch.Tensor`,
# `NDArray[...]` over `np.ndarray`; a bare shape string falls back to `jaxtyping.Float`
ATensor = jaxtype_factory(name="ATensor", array_type=torch.Tensor, default_jax_dtype=jaxtyping.Float)  # type: ignore[misc, assignment]
NDArray = jaxtype_factory(name="NDArray", array_type=np.ndarray, default_jax_dtype=jaxtyping.Float)  # type: ignore[misc, assignment]
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert a numpy dtype (or anything `np.array(..., dtype=...)` accepts) to a torch dtype

    torch dtypes are passed through unchanged.
    """
    if not isinstance(dtype, torch.dtype):
        # let torch pick the matching dtype: wrap a 0-d numpy array of `dtype`
        # and read the dtype of the resulting tensor
        probe: np.ndarray = np.array(0, dtype=dtype)
        return torch.from_numpy(probe).dtype
    return dtype
DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

# `np.float_` / `np.int_` are only added on numpy 1.x (`np.float_` does not exist in
# numpy 2). compare the major version as an integer rather than comparing version
# strings lexicographically (where e.g. `"10.0.0" < "2.0.0"` is True)
if int(np.version.version.split(".")[0]) < 2:
    DTYPE_LIST.extend([np.float_, np.int_])
DTYPE_MAP: dict = dict(
    # every type under its `str()` form, e.g. "torch.float32" or "<class 'numpy.int8'>"
    [(str(dtype), dtype) for dtype in DTYPE_LIST]
    # numpy scalar types additionally under their bare class name, e.g. "int8"
    + [
        (dtype.__name__, dtype)
        for dtype in DTYPE_LIST
        if dtype.__module__ == "numpy"
    ]
)
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = dict(
    (type_name, numpy_to_torch_dtype(type_obj))
    for type_name, type_obj in DTYPE_MAP.items()
)
"mapping from string representations of types to specifically torch types"

# neither `str(bool)` ("<class 'bool'>") nor `str(torch.bool)` ("torch.bool") yields a plain
# "bool" key, and the numpy-name pass only does if `np.bool_.__name__` is "bool"
# (presumably version-dependent -- confirm), so pin "bool" explicitly in both maps
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
# optimizer class name -> the `torch.optim` class of that name
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    optimizer_name: getattr(torch.optim, optimizer_name)
    for optimizer_name in (
        "Adagrad",
        "Adam",
        "AdamW",
        "SparseAdam",
        "Adamax",
        "ASGD",
        "LBFGS",
        "NAdam",
        "RAdam",
        "RMSprop",
        "Rprop",
        "SGD",
    )
}
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor with `pad_value` up to length `padded_length`

    padding is added on the left by default; pass `rpad=True` to add it on the right.
    the padding has the same dtype and device as `tensor`.
    """
    filler: torch.Tensor = torch.full(
        (padded_length - tensor.shape[0],),
        pad_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    pieces = (tensor, filler) if rpad else (filler, tensor)
    return torch.cat(pieces)
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """left-pad a 1-d tensor with `pad_value` up to length `padded_length`

    shorthand for `pad_tensor(..., rpad=False)`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """right-pad a 1-d tensor with `pad_value` up to length `pad_length`

    shorthand for `pad_tensor(..., rpad=True)`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d numpy array with `pad_value` up to length `padded_length`

    padding is added on the left by default; pass `rpad=True` to add it on the right.
    the padding has the same dtype as `array`.
    """
    n_pad: int = padded_length - array.shape[0]
    filler: np.ndarray = np.full((n_pad,), pad_value, dtype=array.dtype)
    if rpad:
        return np.concatenate([array, filler])
    return np.concatenate([filler, array])
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """left-pad a 1-d array with `pad_value` up to length `padded_length`

    shorthand for `pad_array(..., rpad=False)`"""
    return pad_array(
        array=array,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """right-pad a 1-d array with `pad_value` up to length `pad_length`

    shorthand for `pad_array(..., rpad=True)`"""
    return pad_array(
        array=array,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, typing.Any]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict

    each value is replaced by `tuple(v.shape)` and the flat dict is then passed through
    `dotlist_to_nested_dict`, so the result is nested: its values are either shape
    tuples or further dicts. (presumably keys are split on ".", e.g. `"layer.weight"`
    -> `{"layer": {"weight": ...}}` -- depends on `dotlist_to_nested_dict`)
    """
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of `get_dict_shapes`: the nested dict of shapes as indented JSON

    each shape is rendered as a string such as `"(3, 4)"` before nesting.
    """
    # stringify shapes first: with `indent=2`, `json.dumps` would otherwise put every
    # dimension of a shape (a tuple, dumped as a JSON list) on its own line
    shape_strings: dict[str, str] = {}
    for key, tensor in d.items():
        shape_strings[key] = str(tuple(tensor.shape))
    return json.dumps(dotlist_to_nested_dict(shape_strings), indent=2)
class StateDictCompareError(AssertionError):
    """base class for state dict mismatch errors

    the subclasses below say which part (keys, shapes, or values) differed"""


class StateDictKeysError(StateDictCompareError):
    """the two state dicts do not have the same set of keys"""


class StateDictShapeError(StateDictCompareError):
    """the keys match, but some entries have different shapes"""


class StateDictValueError(StateDictCompareError):
    """the keys and shapes match, but some entries have different values"""
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    checks, in order: that both dicts have the same keys, that the values under each
    key have equal `.shape`, and that they are close according to
    `torch.allclose(v1, v2, rtol=rtol, atol=atol)`. returns `None` if every check passes.

    # Parameters:

    - `d1 : dict`
        first dict of tensors, e.g. a state dict
    - `d2 : dict`
        second dict of tensors
    - `rtol : float`
        relative tolerance for `torch.allclose`
        (defaults to `1e-5`)
    - `atol : float`
        absolute tolerance for `torch.allclose`
        (defaults to `1e-8`)
    - `verbose : bool`
        if `True`, include the (nested) shapes of the offending entries in the error message
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    if d1_keys != d2_keys:
        # keep these as sorted *lists* so the message lists keys in a stable order;
        # wrapping the sorted result back in `set()` would discard that ordering
        symmetric_diff: list = sorted(d1_keys ^ d2_keys)
        keys_diff_1: list = sorted(d1_keys - d2_keys)
        keys_diff_2: list = sorted(d2_keys - d1_keys)
        # shape summaries are only built when an error is actually raised
        diff_shapes_1: str = (
            string_dict_shapes({k: d1[k] for k in keys_diff_1})
            if verbose
            else "(verbose = False)"
        )
        diff_shapes_2: str = (
            string_dict_shapes({k: d2[k] for k in keys_diff_2})
            if verbose
            else "(verbose = False)"
        )
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match: shapes first, values only where the shapes agree
    shape_failed: list[str] = []
    vals_failed: list[str] = []
    for k, v1 in d1.items():
        v2 = d2[k]
        if v1.shape != v2.shape:
            shape_failed.append(k)
        elif not torch.allclose(v1, v2, rtol=rtol, atol=atol):
            vals_failed.append(k)

    # shape mismatches are reported in preference to value mismatches
    if shape_failed:
        str_shape_failed: str = (
            string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
        )
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if vals_failed:
        str_vals_failed: str = (
            string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
        )
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )