# Source code for zero.random

"""Random sampling utilities.

Public entry point: :func:`set_randomness`, which seeds `random`, `numpy` and
`torch`, configures :code:`torch.backends.cudnn` and returns a new-style
`numpy.random.Generator`.
"""

# Names exported by ``from zero.random import *``.
__all__ = ['set_randomness']

import random
import secrets
from typing import Any, Callable, Optional

import numpy as np
import torch
from numpy.random import Generator, default_rng


def _default_callback(seed):
    """Print *seed* to stdout.

    Used as the default ``callback`` of ``set_randomness``: echoing the seed
    makes an automatically generated seed visible so the run can be repeated.
    """
    message = 'Seed: {} (see zero.random.set_randomness)'.format(seed)
    print(message)


def set_randomness(
    seed: Optional[int] = None,
    cudnn_deterministic: bool = True,
    cudnn_benchmark: bool = False,
    callback: Optional[Callable[[int], Any]] = _default_callback,
) -> Generator:
    """Set seeds and settings in `random`, `numpy` and `torch`.

    Sets random seed for `random`, `numpy.random`, `torch`, `torch.cuda`, sets
    settings for :code:`torch.backends.cudnn` and builds a NumPy random number
    generator.

    Args:
        seed: the seed for all mentioned libraries. Must be non-negative. If
            omitted, a high-quality seed is generated (a random 128-bit
            integer that almost certainly **does not fit in int64**). In any
            case, :code:`seed % (2 ** 32 - 1)` will be used for everything
            except for building `numpy.random.Generator`, which receives the
            full seed.
        cudnn_deterministic: value for :code:`torch.backends.cudnn.deterministic`
        cudnn_benchmark: value for :code:`torch.backends.cudnn.benchmark`
        callback: a function that takes the seed as the only argument. The
            default callback simply prints the seed via `print` which is
            convenient when `seed` is set to `None`.

    Returns:
        `numpy.random.Generator`: A new style numpy random number generator
        constructed via `numpy.random.default_rng` (it should be used instead
        of functions from `np.random`, see `the document
        <https://numpy.org/doc/stable/reference/random/index.html>`_).

    Raises:
        ValueError: if `seed` is negative. The check happens before any
            global state is modified and before `callback` is called.

    Examples:
        .. testcode::

            rng = set_randomness()  # seed will be generated
            rng = set_randomness(0)

        .. testoutput::

            Seed: ... (see zero.random.set_randomness)
            Seed: 0 (see zero.random.set_randomness)
    """
    if seed is None:
        # See https://numpy.org/doc/1.18/reference/random/bit_generators/index.html#seeding-and-entropy # noqa
        seed = secrets.randbits(128)
    elif seed < 0:
        # numpy.random.default_rng rejects negative seeds with ValueError.
        # Fail here, before the cudnn flags and library seeds are changed and
        # before the callback reports a seed that could not be fully applied.
        raise ValueError(f'seed must be non-negative, got {seed}')

    torch.backends.cudnn.deterministic = cudnn_deterministic  # type: ignore
    torch.backends.cudnn.benchmark = cudnn_benchmark  # type: ignore

    raw_seed = seed
    # The legacy seeding APIs below need a small integer (np.random.seed only
    # accepts 32-bit unsigned values). The modulus is kept unchanged so that
    # previously recorded seeds keep reproducing the same streams.
    seed = raw_seed % (2 ** 32 - 1)
    torch.manual_seed(seed)
    # mypy doesn't know about the following functions
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # type: ignore
    np.random.seed(seed)
    random.seed(seed)

    if callback is not None:
        callback(raw_seed)
    # The new-style Generator accepts arbitrarily large non-negative ints, so
    # it gets the full-precision seed rather than the reduced one.
    return default_rng(raw_seed)