"""BOBE.acquisition: acquisition functions (EI, LogEI, WIPV, WIPStd) and Monte Carlo sampling helpers."""

from typing import Any, List, Optional, Dict, Tuple
import jax
import jax.numpy as jnp
from jax import lax,jit
import numpy as np
from scipy.stats import qmc
from jax.scipy.stats import norm
from jax import config
import tensorflow_probability.substrates.jax as tfp
from .optim import optimize_optax, optimize_optax_vmap, optimize_scipy
from .utils.log import get_logger
from .utils.seed import get_numpy_rng
from .samplers import nested_sampling_Dy, sample_GP_NUTS
from .gp import GP
# Enable 64-bit floats in JAX. NOTE: this is a process-wide JAX setting applied
# as a side effect of importing this module. _log_ei_helper picks its tail
# threshold depending on whether its input is float32 or float64.
config.update("jax_enable_x64", True)
# Module-level logger (name "acq"); used for debug output in the acquisition classes.
log = get_logger("acq")

#------------------Helper functions-------------------------
# These are jax versions of the BoTorch functions. 

def _scaled_improvement(mu, sigma, best_f):
    """Return the standardized improvement ``u = (mu - best_f) / sigma``.

    ``sigma`` is floored at the smallest positive normal number of its floating
    dtype, so ``sigma == 0`` no longer yields ``0 / 0 = nan``. Zero improvement
    maps to ``u = 0``; non-zero improvement maps to a very large (possibly
    infinite) ``|u|`` with the correct sign. For any ``sigma`` above that floor,
    such as the variance-clipped values used by the acquisition functions in
    this module, the result is exactly ``(mu - best_f) / sigma``.

    Parameters
    ----------
    mu : array_like
        Posterior mean, with any exploration offset already subtracted.
    sigma : array_like
        Posterior standard deviation; expected to be non-negative.
    best_f : array_like
        Incumbent value to improve upon.

    Returns
    -------
    jnp.ndarray
        The scaled improvement ``u``.
    """
    sigma = jnp.asarray(sigma)
    if not jnp.issubdtype(sigma.dtype, jnp.floating):
        # jnp.finfo is only defined for floating dtypes.
        sigma = sigma.astype(jnp.result_type(float))
    safe_sigma = jnp.maximum(sigma, jnp.finfo(sigma.dtype).tiny)
    return (mu - best_f) / safe_sigma

def _log_phi(u):
    """Return the log-density of the standard normal distribution at ``u``.

    Evaluated in closed form as ``-(u**2 + log(2*pi)) / 2``. Working directly in
    log space keeps the value finite for large ``|u|``, where ``norm.pdf(u)``
    would underflow to zero.
    """
    log_two_pi = jnp.log(2 * jnp.pi)
    return -0.5 * (u * u + log_two_pi)

def _ei_helper(u):
    """Return the standardized expected-improvement term ``phi(u) + u * Phi(u)``.

    ``phi`` and ``Phi`` are the standard normal pdf and cdf. This direct form is
    fine for moderate ``u``. In the far negative tail the two terms nearly
    cancel and precision is lost, which is why ``_log_ei_helper`` only uses it
    for ``u >= -1``.
    """
    density = norm.pdf(u)
    cumulative = norm.cdf(u)
    return density + u * cumulative

def _log_abs_u_Phi_div_phi(u):
    """
    Return ``log(|u| * Phi(u) / phi(u))``; intended for ``u < 0``.

    Uses the identity ``Phi(u) / phi(u) = sqrt(pi / 2) * erfcx(-u / sqrt(2))``,
    where ``erfcx(z) = exp(z**2) * erfc(z)`` is the scaled complementary error
    function. Because of this, the tiny tail values of ``Phi`` and ``phi`` are
    never formed explicitly.
    """
    scaled_erfc = tfp.math.erfcx(u * (-1.0 / jnp.sqrt(2.0)))
    half_log_pi_over_two = 0.5 * jnp.log(jnp.pi / 2.0)
    return jnp.log(jnp.abs(u) * scaled_erfc) + half_log_pi_over_two

def _log_ei_helper(u):
    """
    Accurately computes log(phi(u) + u * Phi(u)).
    Matches BoTorch branching for stability, based on Ament et al., [arxiv: 2310.20708].

    The computation has three regimes:

    * For ``u > -1``, the EI term is evaluated directly with ``_ei_helper`` and logged.
    * For ``u <= -1``, it uses ``log(phi(u)) + log1mexp(w)`` with
      ``w = log(|u| * Phi(u) / phi(u))``. This avoids the cancellation inside
      ``phi(u) + u * Phi(u)``.
    * For ``u`` at or below a dtype-dependent threshold (-1e6 for float64,
      -1e3 otherwise, i.e. float32), ``log1mexp(w)`` is replaced by its
      asymptotic form ``-2 * log(|u|)``.

    Parameters
    ----------
    u : jnp.ndarray
        Standardized improvement; must have dtype float32 or float64.

    Returns
    -------
    jnp.ndarray
        ``log(phi(u) + u * Phi(u))``, same shape as ``u``.

    Raises
    ------
    TypeError
        If ``u.dtype`` is neither float32 nor float64.
    """
    # The tail threshold below is chosen per dtype, so only these two are supported.
    if u.dtype not in [jnp.float32, jnp.float64]:
        raise TypeError(
            f"LogExpectedImprovement only supports float32 and float64, got {u.dtype}."
        )

    # Switch point between the direct formula (u > bound) and the tail formula.
    bound = -1.0
    # At or below this u the log1mexp term is replaced by its asymptotic form.
    neg_inv_sqrt_eps = -1e6 if u.dtype == jnp.float64 else -1e3

    # u > bound: directly log(EI)
    # jnp.where evaluates both branches for every element, so each branch's
    # input is clamped into the range where that branch's formula is well-behaved.
    u_upper = jnp.where(u < bound, bound, u)
    log_ei_upper = jnp.log(_ei_helper(u_upper))

    # u <= bound: use asymptotic expansion
    # For u < 0: phi(u) + u*Phi(u) = phi(u) * (1 - exp(w)), w = log(|u| * Phi(u) / phi(u)).
    u_lower = jnp.where(u > bound, bound, u)
    u_eps = jnp.where(u_lower < neg_inv_sqrt_eps, neg_inv_sqrt_eps, u_lower)

    w = _log_abs_u_Phi_div_phi(u_eps)
    log_phi_u = _log_phi(u)

    # Far tail (u <= neg_inv_sqrt_eps): log(1 - exp(w)) is approximated by -2 * log(|u|).
    second_term = jnp.where(
        u > neg_inv_sqrt_eps,
        tfp.math.log1mexp(w),
        -2.0 * jnp.log(jnp.abs(u_lower))
    )
    log_ei_lower = log_phi_u + second_term

    return jnp.where(u > bound, log_ei_upper, log_ei_lower)

#------------------The acquisition function classes-------------------------

class AcquisitionFunction:
    """Base class for acquisition functions.

    Acquisition functions guide the selection of new points to evaluate by
    balancing exploration and exploitation. Subclasses must implement the
    `fun` and `get_next_point` methods.

    Attributes
    ----------
    name : str
        Name of the acquisition function.
    optimizer : str
        Optimizer to use ('scipy' or 'optax').
    optimizer_options : dict
        Additional options for the optimizer.
    acq_optimize : callable
        Optimization function (optimize_scipy or optimize_optax).
    """

    name: str = "BaseAcquisitionFunction"

    def __init__(self, optimizer: str = "scipy",
                 optimizer_options: Optional[Dict[str, Any]] = None):
        """
        Parameters
        ----------
        optimizer : str, optional
            'scipy' selects optimize_scipy; any other value selects optimize_optax.
        optimizer_options : dict, optional
            Options forwarded to the optimizer. None (the default) means an
            empty dict, created fresh for each instance.
        """
        self.optimizer = optimizer
        # Fresh dict per instance: a shared `{}` default would leak any mutation
        # into every other instance created without explicit options.
        self.optimizer_options = {} if optimizer_options is None else optimizer_options
        if self.optimizer == "scipy":
            self.acq_optimize = optimize_scipy
        else:
            self.acq_optimize = optimize_optax

    def fun(self, x, *args, **kwargs):
        """Acquisition objective at ``x`` (to be minimized); implemented by subclasses."""
        raise NotImplementedError

    def get_next_point(self, gp: GP, acq_kwargs: Optional[Dict[str, Any]] = None,
                       maxiter: int = 500, n_restarts: int = 8, verbose: bool = True,
                       early_stop_patience: int = 25, rng=None) -> Tuple[np.ndarray, float]:
        """
        Optimize the acquisition function to obtain the next point to sample.

        Parameters
        ----------
        gp : GP
            Gaussian process model.
        acq_kwargs : dict, optional
            Additional arguments for the acquisition function. None means {}.
        maxiter : int, optional
            Maximum number of optimization iterations. Default is 500.
        n_restarts : int, optional
            Number of random restarts for optimization. Default is 8.
        verbose : bool, optional
            Whether to print optimization progress. Default is True.
        early_stop_patience : int, optional
            Patience for early stopping. Default is 25.
        rng : np.random.Generator, optional
            Random number generator. Default is None.

        Returns
        -------
        tuple
            (best_point, best_value) where best_point is shape (ndim,) and
            best_value is the acquisition function value.
        """
        raise NotImplementedError("Base class get_next_point() not implemented")

    def get_next_batch(self, gp: GP, n_batch: int = 1,
                       acq_kwargs: Optional[Dict[str, Any]] = None,
                       maxiter: int = 500, n_restarts: int = 8, verbose: bool = True,
                       early_stop_patience: int = 25,
                       rng=None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get the next batch of points to sample.

        The first point comes from ``get_next_point`` on ``gp``. Each further
        point is chosen on a copy of the GP ("dummy_gp"). Before each choice,
        the copy is updated with the previously selected point, labelled with
        the copy's own predicted mean at that point. This is a
        Kriging-believer-style fantasy, so later picks avoid re-selecting the
        same location.

        Parameters
        ----------
        gp : GP
            Gaussian process model (not modified).
        n_batch : int, optional
            Number of points to return. Default is 1.
        acq_kwargs : dict, optional
            Forwarded to every ``get_next_point`` call. None means {}.
        maxiter, n_restarts, verbose, early_stop_patience :
            Forwarded to every ``get_next_point`` call.
        rng : np.random.Generator, optional
            Random number generator shared by all calls; created via
            ``get_numpy_rng()`` when None.

        Returns
        -------
        tuple of np.ndarray
            (points, acquisition values): the per-point results of
            ``get_next_point``, stacked with ``np.array``.
        """
        rng = rng if rng is not None else get_numpy_rng()
        acq_kwargs = {} if acq_kwargs is None else acq_kwargs
        point_kwargs = dict(maxiter=maxiter, n_restarts=n_restarts, verbose=verbose,
                            early_stop_patience=early_stop_patience, rng=rng)

        x_next, acq_val_next = self.get_next_point(gp, acq_kwargs=acq_kwargs, **point_kwargs)
        x_batch, acq_vals = [x_next], [acq_val_next]

        if n_batch > 1:
            # Create dummy GP without classifier functionality; for now we do not
            # use batching for EI/LogEI.
            # NOTE(review): dummy_gp is built from un-standardized targets
            # (train_y * y_std + y_mean); confirm that predict_mean_single and
            # update operate in those same units.
            dummy_gp = GP(train_x=gp.train_x, train_y=gp.train_y * gp.y_std + gp.y_mean,
                          noise=gp.noise, kernel=gp.kernel_name,
                          lengthscales=gp.lengthscales,
                          kernel_variance=gp.kernel_variance,)
            dummy_gp.update(x_next, dummy_gp.predict_mean_single(x_next))
            for _ in range(1, n_batch):
                x_next, acq_val_next = self.get_next_point(dummy_gp, acq_kwargs=acq_kwargs,
                                                           **point_kwargs)
                x_batch.append(x_next)
                acq_vals.append(acq_val_next)
                dummy_gp.update(x_next, dummy_gp.predict_mean_single(x_next))

        return np.array(x_batch), np.array(acq_vals)
class EI(AcquisitionFunction):
    """Expected Improvement acquisition function.

    EI measures the expected improvement over the current best observed value.
    It balances exploitation (high mean) and exploration (high uncertainty).

    The EI criterion is defined as:
        EI(x) = E[max(f(x) - f_best - zeta, 0)]
    where f_best is the best observed value and zeta is an exploration bonus.

    Parameters
    ----------
    optimizer : str, optional
        Optimizer to use ('scipy' or 'optax'). Default is 'scipy'.
    optimizer_options : dict, optional
        Additional options for the optimizer. None (default) means {}.
    """

    name: str = "EI"

    def __init__(self, optimizer: str = "scipy",
                 optimizer_options: Optional[Dict[str, Any]] = None):
        # Avoid a shared mutable `{}` default across instances.
        optimizer_options = {} if optimizer_options is None else optimizer_options
        super().__init__(optimizer=optimizer, optimizer_options=optimizer_options)
        if optimizer == 'optax':
            # EI is optimized from a stack of restart points; use the
            # optimize_optax_vmap driver instead of plain optimize_optax.
            self.acq_optimize = optimize_optax_vmap

    def fun(self, x, gp, best_y, zeta):
        """
        Compute Expected Improvement at point x.

        Parameters
        ----------
        x : jnp.ndarray
            Point at which to evaluate EI, shape (ndim,).
        gp : GP
            Gaussian process model.
        best_y : float
            Best observed function value.
        zeta : float
            Exploration bonus parameter.

        Returns
        -------
        float
            Negative expected improvement (for minimization), as a 0-d array.
        """
        mu, var = gp.predict_single(x)
        # Floor the variance so sigma > 0. jnp.maximum is what jnp.clip(a_min=)
        # computed, without the deprecated keyword.
        var = jnp.maximum(var, 1e-20)
        sigma = jnp.sqrt(var)
        u = _scaled_improvement(mu - zeta, sigma, best_y)
        ei = _ei_helper(u) * sigma
        return jnp.reshape(-ei, ())  # optimizer minimizes this

    def get_next_point(self, gp, acq_kwargs=None, maxiter: int = 250, n_restarts: int = 20,
                       verbose: bool = True, early_stop_patience: int = 25, rng=None):
        """
        Maximize EI over the unit cube and return the best point found.

        Parameters
        ----------
        gp : GP
            Gaussian process model.
        acq_kwargs : dict, optional
            May contain 'zeta' (default 0.) and 'best_y'. When 'best_y' is
            missing or None, the maximum of ``gp.train_y`` is used.
        maxiter : int, optional
            Maximum optimizer iterations. Default is 250.
        n_restarts : int, optional
            Number of starting points. For values > 1, half (rounded down) are
            ``gp.get_random_point(rng, nstd=5)`` draws and the rest start at the
            best training point. Default is 20.
        verbose : bool, optional
            Forwarded to the optimizer.
        early_stop_patience : int, optional
            Accepted for interface compatibility; not used here.
        rng : np.random.Generator, optional
            Used for random starts and jitter; ``get_numpy_rng()`` when None.

        Returns
        -------
        tuple
            (points, values) from the optimizer, with values negated back to
            positive EI.
        """
        rng = rng if rng is not None else get_numpy_rng()
        acq_kwargs = {} if acq_kwargs is None else acq_kwargs
        zeta = acq_kwargs.get('zeta', 0.)
        best_y = acq_kwargs.get('best_y')
        if best_y is None:
            # Computed only when needed, as one vectorized reduction (the old
            # dict.get default always ran Python's max element-by-element).
            best_y = jnp.max(gp.train_y)
        fun_args = (gp, best_y, zeta)

        best_x = gp.train_x[jnp.argmax(gp.train_y)]
        if n_restarts > 1:
            n_random = n_restarts // 2
            # For a classifier GP, get_random_point is expected to return points
            # inside the positive region (NOTE(review): confirm in GP).
            x0_random = jnp.vstack([gp.get_random_point(rng, nstd=5) for _ in range(n_random)])
            x0_best = jnp.full((n_restarts - n_random, gp.ndim), best_x)
            x0_acq = jnp.vstack([x0_random, x0_best])
        else:
            x0_acq = best_x

        # Small Gaussian jitter so copies of the incumbent start at distinct
        # locations; clip back into the [0, 1] optimization bounds.
        jitter = rng.normal(0., 0.005, size=x0_acq.shape)
        x0_acq = jnp.clip(x0_acq + jitter, 0., 1.)

        pts, vals = self.acq_optimize(fun=self.fun, fun_args=fun_args, fun_kwargs={},
                                      num_params=gp.ndim, x0=x0_acq, bounds=[0, 1],
                                      optimizer_options=self.optimizer_options,
                                      maxiter=maxiter, n_restarts=n_restarts,
                                      verbose=verbose)
        return pts, -vals  # we minimize -EI so return -vals
class LogEI(EI):
    """Log Expected Improvement acquisition function.

    LogEI computes the logarithm of the Expected Improvement. This gives better
    numerical stability than EI, especially when EI values are very small. It
    uses dedicated tail formulas for accurate computation in extreme cases.

    ``fun`` returns the *negative* log-EI (the optimizer minimizes it). The
    inherited ``EI.get_next_point`` negates the optimizer values back to
    positive log-EI.

    Parameters
    ----------
    optimizer : str, optional
        Optimizer to use ('scipy' or 'optax'). Default is 'scipy'.
    optimizer_options : dict, optional
        Additional options for the optimizer. None (default) means {}.

    References
    ----------
    [1] Ament, S., et al. (2023). "Unexpected Improvements to Expected
        Improvement for Bayesian Optimization." arXiv:2310.20708.
    """

    name: str = "LogEI"

    def __init__(self, optimizer: str = "scipy",
                 optimizer_options: Optional[Dict[str, Any]] = None):
        # Avoid a shared mutable `{}` default across instances.
        optimizer_options = {} if optimizer_options is None else optimizer_options
        super().__init__(optimizer=optimizer, optimizer_options=optimizer_options)

    def fun(self, x, gp, best_y, zeta):
        """
        Negative Log Expected Improvement at ``x``, in pure JAX.

        Parameters
        ----------
        x : jnp.ndarray
            Point at which to evaluate, shape (ndim,).
        gp : GP
            Gaussian process model.
        best_y : float
            Best observed function value.
        zeta : float
            Exploration bonus parameter.

        Returns
        -------
        jnp.ndarray
            0-d array holding ``-log(EI(x))``; the optimizer minimizes this.
        """
        mu, var = gp.predict_single(x)
        # Floor the variance so sigma > 0 and log(sigma) is finite.
        var = jnp.maximum(var, 1e-18)
        sigma = jnp.sqrt(var)
        u = _scaled_improvement(mu - zeta, sigma, best_y)
        # log(EI) = log(phi(u) + u*Phi(u)) + log(sigma)
        log_ei = _log_ei_helper(u) + jnp.log(sigma)
        return jnp.reshape(-log_ei, ())  # optimizer minimizes this
class WeightedIntegratedPosteriorBase(AcquisitionFunction):
    """Base class for Weighted Integrated Posterior acquisition functions.

    This base class provides common functionality for acquisition functions
    that integrate over MC samples, such as WIPV and WIPStd.

    Parameters
    ----------
    optimizer : str, optional
        Optimizer to use ('scipy' or 'optax'). Default is 'scipy'.
    optimizer_options : dict, optional
        Additional options for the optimizer. None (default) means {}.
    """

    # When the GP has more training points than this, the optimizer refinement
    # is skipped and the best screened MC point is returned directly.
    max_train_size_for_optimization: int = 500

    def __init__(self, optimizer: str = "scipy",
                 optimizer_options: Optional[Dict[str, Any]] = None):
        # Avoid a shared mutable `{}` default across instances.
        optimizer_options = {} if optimizer_options is None else optimizer_options
        super().__init__(optimizer=optimizer, optimizer_options=optimizer_options)

    def get_next_point(self, gp, acq_kwargs, maxiter: int = 100, n_restarts: int = 1,
                       verbose: bool = True, early_stop_patience: int = 25, rng=None):
        """
        Optimize the acquisition function to obtain the next point.

        This method is shared between WIPV and WIPStd; they follow the same
        procedure and differ only in ``fun``. First, a subset of the MC samples
        is drawn. Then ``fun`` is evaluated with each of those points as the
        candidate, and the minimizer is taken. For small training sets, that
        point is then used as the optimizer's starting point.

        Parameters
        ----------
        gp : GP
            Gaussian process model.
        acq_kwargs : dict
            Must contain 'mc_samples' (a mapping with key 'x'); may contain
            'mc_points_size' (default 128).
        maxiter : int, optional
            Maximum optimization iterations. Default is 100.
        n_restarts : int, optional
            Number of optimization restarts. Default is 1.
        verbose : bool, optional
            Whether to print progress. Default is True.
        early_stop_patience : int, optional
            Accepted for interface compatibility; not used here.
        rng : np.random.Generator, optional
            Used to subsample the MC points. Default is None.

        Returns
        -------
        tuple
            If ``gp.train_x`` has more than ``max_train_size_for_optimization``
            rows: (best screened MC point, its acquisition value as float).
            Otherwise: whatever ``self.acq_optimize`` returns.

        Raises
        ------
        ValueError
            If ``acq_kwargs`` has no 'mc_samples' entry (or it is None).
        """
        mc_samples = acq_kwargs.get('mc_samples')
        if mc_samples is None:
            raise ValueError(
                f"{self.name} requires acq_kwargs['mc_samples'] (e.g. from get_mc_samples)."
            )
        mc_points_size = acq_kwargs.get('mc_points_size', 128)
        mc_points = get_mc_points(mc_samples, mc_points_size=mc_points_size, rng=rng)
        # The train/MC cross-covariance does not depend on the candidate x,
        # so it is computed once and reused for every evaluation.
        k_train_mc = gp.kernel.covariance(gp.train_x, mc_points, include_noise=False)

        @jax.jit
        def acq_at(x):
            return self.fun(x, gp, mc_points=mc_points, k_train_mc=k_train_mc)

        # Screening pass: evaluate the acquisition with each MC point as candidate.
        acq_vals = lax.map(acq_at, mc_points)
        acq_val_min = jnp.min(acq_vals)
        log.debug(f"{self.name} acquisition min value on MC points: {float(acq_val_min):.4e}")
        x0_acq = mc_points[jnp.argmin(acq_vals)]

        if gp.train_x.shape[0] > self.max_train_size_for_optimization:
            return x0_acq, float(acq_val_min)
        return self.acq_optimize(fun=self.fun, fun_args=(gp,),
                                 fun_kwargs={'mc_points': mc_points, 'k_train_mc': k_train_mc},
                                 num_params=gp.ndim, x0=x0_acq, bounds=[0, 1],
                                 optimizer_options=self.optimizer_options,
                                 maxiter=maxiter, n_restarts=n_restarts, verbose=verbose)
class WIPV(WeightedIntegratedPosteriorBase):
    """Weighted Integrated Posterior Variance acquisition function.

    For a candidate point ``x``, WIPV averages ``gp.fantasy_var(new_x=x, ...)``
    over a set of MC points ``x'``. The MC points are drawn by the base class
    from ``acq_kwargs['mc_samples']`` (see ``get_mc_samples``). Presumably
    ``fantasy_var`` is the GP posterior variance at each ``x'`` after
    conditioning on a hypothetical observation at ``x`` (TODO confirm in GP).
    In that case:

        WIPV(x) ~= mean_{x'} Var[f(x') | D, x]

    Minimizing it picks the point expected to most reduce uncertainty where the
    MC samples concentrate.

    Parameters
    ----------
    optimizer : str, optional
        Optimizer to use ('scipy' or 'optax'). Default is 'scipy'.
    optimizer_options : dict, optional
        Additional options for the optimizer. Default is {}.
    """

    name: str = "WIPV"

    def fun(self, x, gp, mc_points=None, k_train_mc = None):
        """Mean of ``gp.fantasy_var`` over ``mc_points`` for candidate ``x``."""
        return jnp.mean(
            gp.fantasy_var(new_x=x, mc_points=mc_points, k_train_mc=k_train_mc)
        )
class WIPStd(WeightedIntegratedPosteriorBase):
    """Weighted Integrated Posterior Standard Deviation acquisition function.

    Same construction as WIPV, except that the element-wise square root of
    ``gp.fantasy_var(new_x=x, ...)`` is averaged over the MC points instead of
    the variance itself:

        WIPStd(x) ~= mean_{x'} sqrt(fantasy_var(x' | x))

    Averaging standard deviations weights low-variance regions relatively more
    than WIPV does, which gives different exploration behaviour.

    Parameters
    ----------
    optimizer : str, optional
        Optimizer to use ('scipy' or 'optax'). Default is 'scipy'.
    optimizer_options : dict, optional
        Additional options for the optimizer. Default is {}.
    """

    name: str = "WIPStd"

    def fun(self, x, gp, mc_points=None, k_train_mc = None):
        """Mean of ``sqrt(gp.fantasy_var)`` over ``mc_points`` for candidate ``x``."""
        fantasy_variance = gp.fantasy_var(new_x=x, mc_points=mc_points, k_train_mc=k_train_mc)
        # NOTE(review): any negative round-off in fantasy_var becomes NaN here;
        # confirm fantasy_var is guaranteed non-negative.
        per_point_std = jnp.sqrt(fantasy_variance)
        return jnp.mean(per_point_std)
def get_mc_samples(gp: GP, warmup_steps=512, num_samples=1024, thinning=4, method="NUTS",
                   num_chains=4, np_rng=None, rng_key=None):
    """Draw Monte Carlo points for the integrated-posterior acquisition functions.

    Parameters
    ----------
    gp : GP
        Surrogate model; ``gp.ndim`` sets the dimensionality.
    warmup_steps, thinning, num_chains, rng_key :
        Used only by the 'NUTS' method (forwarded to ``sample_GP_NUTS``).
    num_samples : int, optional
        Number of samples for 'NUTS' and 'uniform'. For 'uniform', SciPy's
        Sobol generator warns if this is not a power of two.
    method : {'NUTS', 'NS', 'uniform'}, optional
        'NUTS'    -> result of ``sample_GP_NUTS`` returned as-is.
        'NS'      -> equally weighted samples from ``nested_sampling_Dy`` in
                     'acq' mode (maxcall=2e6, dlogz=0.02). The log-evidence and
                     success flag it also returns are discarded.
        'uniform' -> ``{'x': points}``: ``num_samples`` scrambled Sobol points
                     in the unit hypercube.
    np_rng : np.random.Generator, optional
        NumPy RNG passed to every method.

    Returns
    -------
    dict-like
        Sample container; downstream ``get_mc_points`` reads its ``'x'`` entry.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'NUTS':
        return sample_GP_NUTS(gp=gp, warmup_steps=warmup_steps, num_samples=num_samples,
                              thinning=thinning, num_chains=num_chains,
                              np_rng=np_rng, rng_key=rng_key)

    if method == 'NS':
        # NOTE(review): the success flag is ignored; a failed run still returns samples.
        samples, _logz, _success = nested_sampling_Dy(gp=gp, ndim=gp.ndim, mode='acq',
                                                      maxcall=int(2e6), dynamic=False,
                                                      dlogz=0.02, equal_weights=True,
                                                      rng=np_rng)
        return samples

    if method == 'uniform':
        sobol = qmc.Sobol(gp.ndim, scramble=True, rng=np_rng)
        return {'x': sobol.random(num_samples)}

    raise ValueError(f"Unknown method {method} for sampling GP")
def get_mc_points(mc_samples, mc_points_size=128, rng=None):
    """Randomly subsample rows of ``mc_samples['x']`` without replacement.

    Parameters
    ----------
    mc_samples : mapping
        Must provide ``'x'``, an array whose first axis indexes samples.
    mc_points_size : int, optional
        Requested number of points (default 128). If fewer samples are
        available, all of them are returned, in random order.
    rng : np.random.Generator, optional
        Source of randomness; ``get_numpy_rng()`` is used when None.

    Returns
    -------
    array
        ``min(len(mc_samples['x']), mc_points_size)`` distinct rows of
        ``mc_samples['x']``.
    """
    samples = mc_samples['x']
    n_available = samples.shape[0]
    # The population is the number of available samples. The previous
    # max(n_available, mc_points_size) drew out-of-range indices whenever fewer
    # samples than requested were available.
    n_points = min(n_available, mc_points_size)
    rng = rng if rng is not None else get_numpy_rng()
    idxs = rng.choice(n_available, size=n_points, replace=False)
    return samples[idxs]