Source code for BOBE.utils.seed

"""
Utility functions for managing global random seeds across the BOBE package.
"""

import os
import random
import numpy as np
import jax
import jax.random as jax_random
from .log import get_logger
log = get_logger("seed_utils")

# Global variables
_global_seed: int | None = None
_jax_key: jax.Array | None = None
_np_rng: np.random.Generator | None = None


def _ensure_seed_is_set():
    """Internal helper to initialize the global seed if it hasn't been set."""
    if _global_seed is None:
        log.info("Global seed not set. Initializing with a random seed.")
        set_global_seed()


[docs] def set_global_seed(seed: int | None = None) -> int: """Set global random seed for reproducible results. Args: seed: The random seed to use. If None, a random seed is generated. Returns: The seed that was used. """ global _global_seed, _jax_key, _np_rng if seed is None: seed = random.randint(0, 2**31 - 1) log.info(f"No seed provided. Generated a random seed: {seed}") elif not isinstance(seed, int) or seed < 0: raise ValueError("Seed must be a non-negative integer or None") _global_seed = seed random.seed(seed) _np_rng = np.random.default_rng(seed) _jax_key = jax_random.PRNGKey(seed) os.environ['PYTHONHASHSEED'] = str(seed) log.info(f"Global random seed set to {seed}") return seed
[docs] def get_global_seed() -> int: """Get the current global seed value. If the seed has not been set, it will be initialized automatically. Returns: The current global seed. """ _ensure_seed_is_set() return _global_seed
[docs] def get_jax_key() -> jax.Array: """Get the current JAX random key. If the seed has not been set, it will be initialized automatically. Returns: The current JAX PRNGKey. """ _ensure_seed_is_set() return _jax_key
[docs] def split_jax_key() -> tuple[jax.Array, jax.Array]: """Split the current JAX random key and update the global key. If the seed has not been set, it will be initialized automatically. Returns: A tuple containing the new global key and the key for use. """ global _jax_key _ensure_seed_is_set() _jax_key, use_key = jax_random.split(_jax_key) return _jax_key, use_key
[docs] def get_new_jax_key() -> jax.Array: """Get a new JAX random key by splitting the current global key. If the seed has not been set, it will be initialized automatically. Returns: A new JAX PRNGKey for immediate use. """ _, use_key = split_jax_key() return use_key
[docs] def get_numpy_rng() -> np.random.Generator: """Get the global NumPy random number generator. If the seed has not been set, it will be initialized automatically. Returns: The global instance of numpy.random.Generator. """ _ensure_seed_is_set() return _np_rng
[docs] def ensure_reproducibility(seed: int | None = None) -> int: """Ensure reproducibility by setting seeds and JAX configurations. Args: seed: The seed to use. If None, a random seed will be generated. Returns: The seed that was used. """ used_seed = set_global_seed(seed) jax.config.update("jax_enable_x64", True) log.info(f"Reproducibility ensured with seed {used_seed}") return used_seed