Source code for BOBE.gp

from math import sqrt,pi
from typing import Any,List
import jax.numpy as jnp
import numpy as np
import jax
from jax.scipy.linalg import cho_solve, solve_triangular
jax.config.update("jax_enable_x64", True)
from functools import partial
from .utils.log import get_logger
log = get_logger("gp")
from .optim import optimize_optax, optimize_scipy
from .utils.seed import get_new_jax_key, get_numpy_rng
import numpyro.distributions as dist
from .kernels import Kernel, RBFKernel, MaternKernel


safe_noise_floor = 1e-12

# Constants for DSLP prior
sqrt2 = sqrt(2.)
sqrt3 = sqrt(3.)

[docs] class DummyDistribution: """A dummy distribution that always returns log_prob = 0.0"""
[docs] def log_prob(self, x): return 0.0
[docs] def make_distribution(spec: dict) -> dist.Distribution: """ Turn a dictionary specification into a NumPyro distribution. Parameters ---------- spec : dict Dictionary with 'name' key for distribution type and additional keyword arguments for the distribution parameters. Returns ------- dist.Distribution NumPyro distribution object. Examples -------- >>> spec = {"name": "Normal", "loc": 0.0, "scale": 1.0} >>> dist = make_distribution(spec) """ # Ensure distribution exists dist_class = getattr(dist, spec["name"], None) if dist_class is None: raise ValueError(f"Distribution {spec['name']} not found in numpyro.distributions.") # Remove "name" kwargs = {k: v for k, v in spec.items() if k != "name"} return dist_class(**kwargs)
[docs] def saas_prior_logprob(lengthscales, kernel_variance, tausq): """ Compute SAAS prior log probability. Parameters ---------- lengthscales : jnp.ndarray Lengthscale parameters. kernel_variance : float Kernel variance parameter. tausq : float SAAS tausq parameter. Returns ------- float Log probability under SAAS priors. """ logprior = dist.LogNormal(0., 1.).log_prob(kernel_variance) logprior += dist.HalfCauchy(0.1).log_prob(tausq) inv_lengthscales_sq = 1 / (tausq * lengthscales**2) logprior += jnp.sum(dist.HalfCauchy(1.).log_prob(inv_lengthscales_sq)) return logprior
[docs] @jax.jit def gp_mll(k,train_y,num_points): """ Computes the negative marginal log likelihood of the GP """ L = jnp.linalg.cholesky(k) alpha = cho_solve((L,True),train_y) mll = -0.5*jnp.einsum("ij,ji",train_y.T,alpha) - jnp.sum(jnp.log(jnp.diag(L))) - 0.5*num_points*jnp.log(2*pi) return mll
[docs] @jax.jit def fast_update_cholesky(L: jnp.ndarray, k: jnp.ndarray, k_self: float): # solve L v = k -> v has shape (n,) v = solve_triangular(L, k, lower=True) # new diagonal entry diag = jnp.sqrt(k_self - jnp.dot(v, v)) # print(f"Shapes L: {L.shape}, k: {k.shape}, k_self: {k_self}, v: {v.shape}, diag: {diag.shape}") # build a zero (n+1)x(n+1) and fill blocks n = L.shape[0] new_L = jnp.zeros((n+1, n+1), dtype=L.dtype) new_L = new_L.at[:n, :n].set(L) # top-left new_L = new_L.at[n, :n].set(v) # bottom-left new_L = new_L.at[n, n].set(diag) # bottom-right return new_L
[docs] class GP:
[docs] def __init__(self,train_x,train_y,noise=1e-8,kernel="rbf",optimizer="scipy",optimizer_options={}, kernel_variance_bounds = [1e-4, 1e8],lengthscale_bounds = [0.01,5],lengthscales=None,kernel_variance=None, kernel_variance_prior=None, lengthscale_prior=None, tausq=None, tausq_bounds=[1e-4,1e4], param_names: List[str] = None): """ Initialize the Gaussian Process model. Parameters ---------- train_x : jnp.ndarray Training inputs, shape (N, D). train_y : jnp.ndarray Objective function values at training points, shape (N, 1). noise : float, optional Noise parameter added to the diagonal of the kernel. Default is 1e-8. kernel : str, optional Kernel to use, either "rbf" or "matern". Default is "rbf". optimizer : str, optional Optimizer to use for hyperparameter tuning. Default is "scipy". optimizer_options : dict, optional Keyword arguments for the optimizer. Default is {}. kernel_variance_bounds : list, optional Bounds for the kernel variance. Default is [1e-4, 1e8]. lengthscale_bounds : list, optional Bounds for the lengthscales. Default is [0.01, 10]. lengthscales : jnp.ndarray, optional Initial lengthscale values. If None, defaults to ones. Default is None. kernel_variance : float, optional Initial kernel variance. If None, defaults to 1.0. Default is None. kernel_variance_prior : dict or str, optional Specification for the kernel variance prior. If None, defaults to `{'name': 'LogNormal', 'loc': 0.0, 'scale': 1.0}`. If 'fixed', the kernel variance will be fixed to the initial value and not optimized. Defaults to None. lengthscale_prior : str or dict, optional Specification for the lengthscale prior. If 'DSLP' or None, uses the DSLP prior. If 'SAAS', uses the SAAS prior with tausq parameter. Otherwise, uses the provided distribution spec. Defaults to None. tausq : float, optional Initial tausq parameter for SAAS prior. Only used when lengthscale_prior='SAAS'. If None, defaults to 1.0. Defaults to None. tausq_bounds : list, optional Bounds for the tausq parameter (in log space). Only used when lengthscale_prior='SAAS'. Defaults to [-4, 4]. """ # Setup and validate training data self._setup_training_data(train_x, train_y) self.param_names = param_names if param_names is not None else ['x_'+str(i) for i in range(self.ndim)] # Setup kernel and initial hyperparameters self.kernel_name = kernel if kernel == "rbf" else "matern" self.lengthscales = lengthscales if lengthscales is not None else jnp.ones(self.ndim) self.kernel_variance = kernel_variance if kernel_variance is not None else 1.0 self.noise = noise # Instantiate kernel object kernel_classes = {"rbf": RBFKernel, "matern": MaternKernel} self.kernel = kernel_classes[self.kernel_name](self.lengthscales, self.kernel_variance, self.noise) # Compute initial kernel matrices K = self.kernel.covariance(self.train_x, self.train_x, include_noise=True) self.cholesky = jnp.linalg.cholesky(K) self.alphas = cho_solve((self.cholesky, True), self.train_y) # Setup optimizer self.optimizer_method = optimizer if optimizer == "scipy": self.mll_optimize = optimize_scipy else: self.mll_optimize = optimize_optax self.optimizer_options = optimizer_options # Store bounds self.lengthscale_bounds = lengthscale_bounds self.kernel_variance_bounds = kernel_variance_bounds # Can store tausq for convenience even though it is only used for SAAS self.tausq = tausq if tausq is not None else 1.0 self.tausq_bounds = tausq_bounds # Setup priors and optimization parameters self._setup_kernel_variance_prior(kernel_variance_prior) self._setup_lengthscale_prior(lengthscale_prior) self._setup_optimization_parameters()
def _setup_training_data(self, train_x, train_y): """Setup and validate training data, compute standardization parameters.""" # Check x and y sizes if train_x.shape[0] != train_y.shape[0]: raise ValueError("train_x and train_y must have the same number of points") if train_y.ndim != 2: train_y = train_y.reshape(-1, 1) if train_x.ndim != 2: raise ValueError("train_x must be 2D") self.ndim = train_x.shape[1] # Compute standardization parameters (and handle the case of 0 initialisation points) self.y_mean = jnp.mean(train_y) if train_y.size > 0 else 0 self.y_std = jnp.std(train_y) if train_y.size > 0 else 1.0 # Handle edge case where std is zero (all values identical or only 1 point) if self.y_std == 0: log.warning("Training targets have zero variance. Setting std to 1.0 to avoid division by zero.") self.y_std = 1.0 # Store standardized training data self.train_x = jnp.array(train_x) self.train_y = (train_y - self.y_mean) / self.y_std log.debug(f"GP training size = {self.train_x.shape[0]}") def _setup_kernel_variance_prior(self, kernel_variance_prior): """Setup kernel variance prior and determine if it should be fixed.""" self.kernel_variance_prior_spec = kernel_variance_prior if self.kernel_variance_prior_spec is None: self.kernel_variance_prior_spec = {'name': 'Uniform', 'low': self.kernel_variance_bounds[0], 'high': self.kernel_variance_bounds[1]} # Check if kernel variance should be fixed self.fixed_kernel_variance = (self.kernel_variance_prior_spec == 'fixed') if not self.fixed_kernel_variance: self.kernel_variance_prior_dist = make_distribution(self.kernel_variance_prior_spec) else: self.kernel_variance_prior_dist = DummyDistribution() def _setup_lengthscale_prior(self, lengthscale_prior): """Setup lengthscale prior and determine prior function.""" self.lengthscale_prior_spec = lengthscale_prior if self.lengthscale_prior_spec is None: self.lengthscale_prior_spec = {'name': 'Uniform', 'low': self.lengthscale_bounds[0], 'high': self.lengthscale_bounds[1]} # Set up lengthscale priors and prior function if self.lengthscale_prior_spec == 'DSLP': self.lengthscale_prior_dist = dist.LogNormal(loc=sqrt2 + 0.5*jnp.log(self.ndim), scale=sqrt3) self.prior_func = self._standard_prior_logprob elif self.lengthscale_prior_spec == 'SAAS': self.lengthscale_prior_dist = None self.prior_func = self._saas_prior_logprob else: self.lengthscale_prior_dist = make_distribution(self.lengthscale_prior_spec) self.prior_func = self._standard_prior_logprob def _setup_optimization_parameters(self): """Setup parameter names and bounds for optimization.""" # Build parameter names and bounds based on what's being optimized self.hyperparam_names = ['lengthscales'] self.hyperparam_bounds = [self.lengthscale_bounds] * self.ndim if not self.fixed_kernel_variance: self.hyperparam_names.append('kernel_variance') self.hyperparam_bounds.append(self.kernel_variance_bounds) if self.lengthscale_prior_spec == 'SAAS': self.hyperparam_names.append('tausq') self.hyperparam_bounds.append(self.tausq_bounds) self.hyperparam_bounds = jnp.log(jnp.array(self.hyperparam_bounds).T) self.num_hyperparams = self.hyperparam_bounds.shape[1] log.debug(f" Hyperparameter bounds = {self.hyperparam_bounds}") def _standard_prior_logprob(self, lengthscales, kernel_variance, tausq=None): """Standard prior log probability for DSLP and custom priors.""" logprior = self.kernel_variance_prior_dist.log_prob(kernel_variance) if self.lengthscale_prior_dist is not None: logprior += self.lengthscale_prior_dist.log_prob(lengthscales).sum() return logprior def _saas_prior_logprob(self, lengthscales, kernel_variance, tausq): """SAAS prior log probability.""" return saas_prior_logprob(lengthscales, kernel_variance, tausq) def _parse_hyperparams(self, log_params): """Parse log parameters into lengthscales, kernel_variance, and optionally tausq.""" hyperparams = jnp.exp(log_params) lengthscales = hyperparams[:self.ndim] if self.fixed_kernel_variance: kernel_variance = self.kernel_variance # Use fixed value if 'tausq' in self.hyperparam_names: tausq = hyperparams[self.ndim] if len(hyperparams) > self.ndim else self.tausq else: tausq = self.tausq else: kernel_variance = hyperparams[self.ndim] tausq = hyperparams[self.ndim + 1] if len(hyperparams) > self.ndim + 1 else self.tausq return lengthscales, kernel_variance, tausq
[docs] def neg_mll(self, log_params): """ Computes the negative log marginal likelihood for the GP with given hyperparameters. """ lengthscales, kernel_variance, tausq = self._parse_hyperparams(log_params) # Update kernel hyperparameters and compute kernel matrix self.kernel.update_hyperparams(lengthscales=lengthscales, kernel_variance=kernel_variance) K = self.kernel.covariance(self.train_x, self.train_x, include_noise=True) mll = gp_mll(K, self.train_y, self.train_y.shape[0]) # Add prior mll += self.prior_func(lengthscales, kernel_variance, tausq) return -mll
[docs] def fit(self, x0: np.ndarray = None, maxiter: int = 500) -> dict: """ Performs a serial fit for a given batch of starting points (x0). This method is called by each MPI process on its assigned chunk. Arguments --------- x0 : np.ndarray Array of shape (n_restarts_chunk, n_params) containing starting points for optimization (in log space). maxiter : int Maximum number of iterations for the optimizer. Defaults to 500. Returns ------- result : dict Dictionary containing the best 'mll' and corresponding 'params' (log space) found. """ if x0 is None: # set to current hyperparameters x0 = jnp.log(self.get_hyperparams())[None, :] optimizer_options = self.optimizer_options.copy() best_params_log, best_loss = self.mll_optimize( fun=self.neg_mll, num_params=self.num_hyperparams, bounds=self.hyperparam_bounds, x0=x0, # Use the chunk of starting points passed in maxiter=maxiter, n_restarts=x0.shape[0], # The number of restarts is the size of the chunk optimizer_options=optimizer_options ) # Return the result in the format the pool expects return { 'mll': -best_loss, 'params': best_params_log # Optionally return the raw params }
[docs] def update_hyperparams(self, hyperparams): """ Update the GP hyperparameters and recompute the Cholesky and alphas. """ lengthscales, kernel_variance, tausq = self._parse_hyperparams(hyperparams) self.lengthscales = lengthscales if not self.fixed_kernel_variance: self.kernel_variance = kernel_variance self.tausq = tausq # Update kernel object self.kernel.update_hyperparams(lengthscales=self.lengthscales, kernel_variance=self.kernel_variance) self.recompute_cholesky()
[docs] def predict_mean_single(self,x): """ Single point prediction of mean """ x = jnp.atleast_2d(x) k12 = self.kernel.covariance(self.train_x, x, include_noise=False) # shape (N,1) mean = jnp.einsum('ij,ji', k12.T, self.alphas)*self.y_std + self.y_mean return mean
[docs] def predict_var_single(self,x): x = jnp.atleast_2d(x) k12 = self.kernel.covariance(self.train_x, x, include_noise=False) # shape (N,1) vv = solve_triangular(self.cholesky, k12, lower=True) # shape (N,1) k22 = self.kernel.diagonal(x, include_noise=True) # shape (1,) for x (1,ndim) var = k22 - jnp.sum(vv*vv,axis=0) var = jnp.clip(var, safe_noise_floor, None) return self.y_std**2 * var.squeeze()
[docs] def predict_mean_batched(self,x): x = jnp.atleast_2d(x) return jax.vmap(self.predict_mean_single, in_axes=0)(x)
[docs] def predict_var_batched(self,x): x = jnp.atleast_2d(x) return jax.vmap(self.predict_var_single, in_axes=0)(x)
[docs] def predict_single(self,x): """ Predicts the mean and variance of the GP at x but does not unstandardize it. To use with EI and the like. """ x = jnp.atleast_2d(x) k12 = self.kernel.covariance(self.train_x, x, include_noise=False) k22 = self.kernel.diagonal(x, include_noise=True) mean = jnp.einsum('ij,ji', k12.T, self.alphas) vv = solve_triangular(self.cholesky, k12, lower=True) # shape (N,1) var = k22 - jnp.sum(vv*vv,axis=0) # handle nans and negative variances due to numerical issues var = jnp.where(jnp.isnan(var),safe_noise_floor,var) var = jnp.where(var<safe_noise_floor,safe_noise_floor,var) return mean, var
[docs] def predict_batched(self,x): x = jnp.atleast_2d(x) return jax.vmap(self.predict_single, in_axes=0,out_axes=(0,0))(x)
[docs] def update(self,new_x,new_y): """ Updates the GP with new training points and refits the GP if refit is True. Arguments --------- refit: bool Whether to refit the GP hyperparameters. Default is True. maxiter: int The maximum number of iterations for the optax optimizer. Default is 200. n_restarts: int The number of restarts for the optax optimizer. Default is 4. """ new_x = jnp.atleast_2d(new_x) new_y = jnp.atleast_2d(new_y) duplicate = False new_pts_to_add = [] new_vals_to_add = [] # Check for duplicates and collect new points for i in range(new_x.shape[0]): if jnp.any(jnp.all(jnp.isclose(self.train_x, new_x[i], atol=1e-6, rtol=1e-4), axis=1)): log.debug(f"Point {new_x[i]} already exists in the training set, not updating") else: new_pts_to_add.append(new_x[i]) new_vals_to_add.append(new_y[i]) # Add new points if any if new_pts_to_add: new_pts_to_add = jnp.array(new_pts_to_add) new_vals_to_add = jnp.array(new_vals_to_add) # Add to training data self.train_x = jnp.vstack([self.train_x, new_pts_to_add]) train_y_original = jnp.vstack([self.train_y * self.y_std + self.y_mean, new_vals_to_add]) self.y_mean = jnp.mean(train_y_original) self.y_std = jnp.std(train_y_original) if self.y_std == 0: log.warning("Training targets have zero variance. Setting std to 1.0 to avoid division by zero.") self.y_std = 1.0 self.train_y = (train_y_original - self.y_mean) / self.y_std self.recompute_cholesky()
[docs] def recompute_cholesky(self): """ Recomputes the Cholesky decomposition and alphas. Useful if hyperparameters are changed manually. """ K = self.kernel.covariance(self.train_x, self.train_x, include_noise=True) self.cholesky = jnp.linalg.cholesky(K) self.alphas = cho_solve((self.cholesky, True), self.train_y)
[docs] def fantasy_var(self,new_x,mc_points,k_train_mc): """ Computes the variance of the GP at the mc_points assuming a single point new_x is added to the training set """ new_x = jnp.atleast_2d(new_x) # new_train_x = jnp.concatenate([self.train_x,new_x]) k = self.kernel.covariance(self.train_x, new_x, include_noise=False).flatten() # shape (n,) k_self = self.kernel.diagonal(new_x, include_noise=True)[0] # scalar k11_cho = fast_update_cholesky(self.cholesky,k,k_self) # Compute only the extra row for new_x k_new_mc = self.kernel.covariance(new_x, mc_points, include_noise=False) # shape (1, n_mc) k12 = jnp.vstack([k_train_mc,k_new_mc]) k22 = self.kernel.diagonal(mc_points, include_noise=True) # (N_mc,) vv = solve_triangular(k11_cho, k12, lower=True) # shape (N_train,N_mc) var = k22 - jnp.sum(vv*vv,axis=0) # handle nans and negative variances due to numerical issues var = jnp.where(jnp.isnan(var),safe_noise_floor,var) var = jnp.where(var<safe_noise_floor,safe_noise_floor,var) return var * self.y_std**2 # return to physical scale for better interpretability
[docs] def get_random_point(self,rng=None,nstd=None): """ Returns a random point in the unit cube. """ log.debug(f"Getting random point in unit cube") rng = rng if rng is not None else get_numpy_rng() pt = rng.uniform(0, 1, size=self.train_x.shape[1]) return pt
[docs] def state_dict(self): """ Returns a dictionary containing the complete state of the GP. This can be used for saving, loading, or copying the GP. Returns ------- state: dict Dictionary containing all necessary information to reconstruct the GP """ state = { # Training data (original, unstandardized) 'train_x': np.array(self.train_x), 'train_y': np.array(self.train_y * self.y_std + self.y_mean), # unstandardize # Hyperparameters 'lengthscales': np.array(self.lengthscales), 'kernel_variance': float(self.kernel_variance), 'noise': float(self.noise), 'tausq': float(self.tausq), # Standardization parameters 'y_mean': float(self.y_mean), 'y_std': float(self.y_std), # Model configuration 'kernel_name': self.kernel_name, 'lengthscale_prior_spec': self.lengthscale_prior_spec, 'kernel_variance_prior_spec': self.kernel_variance_prior_spec, 'fixed_kernel_variance': self.fixed_kernel_variance, 'optimizer_method': self.optimizer_method, 'optimizer_options': self.optimizer_options, # Bounds 'lengthscale_bounds': self.lengthscale_bounds, 'kernel_variance_bounds': self.kernel_variance_bounds, 'tausq_bounds': self.tausq_bounds, # Computed state 'cholesky': np.array(self.cholesky) if hasattr(self, 'cholesky') else None, 'alphas': np.array(self.alphas) if hasattr(self, 'alphas') else None, # Dimensions 'ndim': self.ndim, # Class identifier 'gp_class': 'GP' } return state
[docs] @classmethod def from_state_dict(cls, state): """ Creates a GP instance from a state dictionary. Arguments --------- state: dict State dictionary returned by state_dict() Returns ------- gp: GP The reconstructed GP object """ # Create GP instance gp = cls( train_x=state['train_x'], train_y=state['train_y'], noise=state['noise'], kernel=state['kernel_name'], optimizer=state['optimizer_method'], optimizer_options=state['optimizer_options'], lengthscales=state['lengthscales'], kernel_variance=state['kernel_variance'], lengthscale_bounds=state['lengthscale_bounds'], kernel_variance_bounds=state['kernel_variance_bounds'], kernel_variance_prior=state.get('kernel_variance_prior_spec'), lengthscale_prior=state.get('lengthscale_prior_spec'), tausq=state.get('tausq', 1.0), tausq_bounds=state.get('tausq_bounds', [-4, 4]) ) # Restore computed state if available if state['cholesky'] is not None: gp.cholesky = jnp.array(state['cholesky']) if state['alphas'] is not None: gp.alphas = jnp.array(state['alphas']) return gp
[docs] @classmethod def load(cls, filename, **kwargs): """ Loads a GP from a file Arguments --------- filename: str The name of the file to load the GP from (with or without .npz extension) **kwargs: Additional keyword arguments to pass to the GP constructor Returns ------- gp: GP The loaded GP object """ if not filename.endswith('.npz'): filename += '.npz' try: data = np.load(filename, allow_pickle=True) except FileNotFoundError: raise FileNotFoundError(f"Could not find file {filename}") # Convert arrays back to the expected format state = {} for key in data.files: value = data[key] if isinstance(value, np.ndarray) and value.shape == (): # Handle scalar arrays state[key] = value.item() else: state[key] = value # Apply any override kwargs state.update(kwargs) # Use from_state_dict for loading gp = cls.from_state_dict(state) log.info(f"Loaded GP from {filename} with {gp.train_x.shape[0]} training points") return gp
[docs] def save(self, filename='gp'): """ Save the GP state to a file using state_dict. Arguments --------- filename: str The filename to save to (with or without .npz extension). Default is 'gp'. """ if not filename.endswith('.npz'): filename += '.npz' state = self.state_dict() np.savez(filename, **state) log.info(f"Saved GP state to {filename}")
[docs] def copy(self): """ Creates a deep copy of the GP using state_dict. Returns ------- gp_copy: GP A deep copy of the current GP """ state = self.state_dict() return self.__class__.from_state_dict(state)
@property def npoints(self): return self.train_x.shape[0]
[docs] def get_hyperparams(self): hp = self.lengthscales if not self.fixed_kernel_variance: hp = jnp.hstack([hp, self.kernel_variance]) if self.lengthscale_prior_spec == 'SAAS': hp = jnp.hstack([hp, self.tausq]) return hp
[docs] def hyperparams_dict(self): ls_str = {name: f"{float(val):.4f}" for name, val in zip(self.param_names, self.lengthscales)} param_dict = { 'lengthscales': ls_str, 'kernel_variance': f"{float(self.kernel_variance):.4f}", } if 'tausq' in self.hyperparam_names: param_dict['tausq'] = f"{float(self.tausq):.4f}" return param_dict