Source code for BOBE.samplers

# This module manages the samplers used to run HMC/nested sampling, using the GP model as a surrogate for the objective function.
# It contains two sampler entry points, one for the Dynesty nested sampler and the other for the HMC sampler using NUTS from numpyro,
# plus small helper functions used by the nested-sampling path.
import time
from typing import Any, List, Optional, Dict, Union
import jax.numpy as jnp
import jax.random as random
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)
from numpyro.util import enable_x64
enable_x64()
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.initialization import init_to_value
from .gp import GP
from .clf_gp import GPwithClassifier
from .utils.log import get_logger
from .utils.seed import get_new_jax_key, get_numpy_rng
from .utils.core import is_cluster_environment, renormalise_log_weights, resample_equal
log = get_logger("sampler")

from dynesty import NestedSampler as StaticNestedSampler, DynamicNestedSampler
import math

# dynesty utility function for computing evidence
def compute_integrals(logl=None, logvol=None, reweight=None, squared=False):
    """Accumulate the log-evidence of a nested-sampling run with the trapezoid rule.

    Adapted from the dynesty utility of the same name. The prior volume is
    assumed to start at ``X_0 = 1`` (``logvol_0 = 0``) and the likelihood is
    padded with an effectively zero value (``log L = -1e300``) at that point.

    Parameters
    ----------
    logl : array_like, 1-D
        Log-likelihood of each dead point, in the order they were removed.
    logvol : array_like, 1-D
        Log prior volume associated with each dead point (same length as
        ``logl``).
    reweight : array_like or float, optional
        Extra log-weight added to every trapezoid term before accumulation.
    squared : bool, optional
        If True, the log volume element ``log(X_{i-1} - X_i)`` is doubled,
        i.e. squared volume elements are used. The ``1/2`` trapezoid factor
        is applied afterwards and is therefore not squared.

    Returns
    -------
    saved_logz : ndarray
        Running ``logaddexp`` accumulation of the log-weights; the last entry
        is the integral over the full run.

    Raises
    ------
    ValueError
        If ``logl`` or ``logvol`` is not provided.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if logl is None or logvol is None:
        raise ValueError("compute_integrals requires both `logl` and `logvol`.")

    logl = np.asarray(logl)
    logvol = np.asarray(logvol)
    loglstar_pad = np.concatenate([[-1.e300], logl])

    # We want log(exp(LV_{i-1}) - exp(LV_i)), assuming LV_0 = 0:
    #   log(exp(LV_{i-1}) - exp(LV_i))
    #     = LV_{i-1} + log(1 - exp(LV_i - LV_{i-1}))
    #     = LV_i - (LV_i - LV_{i-1}) + log(1 - exp(LV_i - LV_{i-1}))
    dlogvol = np.diff(logvol, prepend=0)
    # logdvol is log(delta(volumes)), i.e. log(X_{i-1} - X_i)
    logdvol = logvol - dlogvol + np.log1p(-np.exp(dlogvol))
    if squared:
        logdvol = 2 * logdvol

    # log(1/2 * (X_{i-1} - X_i))
    logdvol2 = logdvol + math.log(0.5)

    # log((L_i + L_{i-1}) * (X_{i-1} - X_i) / 2)
    saved_logwt = np.logaddexp(loglstar_pad[1:], loglstar_pad[:-1]) + logdvol2
    if reweight is not None:
        saved_logwt = saved_logwt + reweight

    saved_logz = np.logaddexp.accumulate(saved_logwt)
    return saved_logz
def prior_transform(x):
    """Identity prior transform: hand the input point back unchanged.

    Supplied to the dynesty sampler as its prior-transform callable.
    """
    return x
def nested_sampling_Dy(gp: GP,
                       mode: str = 'acq',
                       ndim: int = 1,
                       dlogz: float = 0.1,
                       dynamic: bool = False,
                       maxcall: Optional[int] = int(5e6),
                       print_progress: Optional[bool] = True,
                       equal_weights: bool = False,
                       sample_method: str = 'rwalk',
                       rng=None,
                       ) -> tuple[Dict, Dict, bool]:
    """
    Nested sampling with Dynesty, using the GP mean as the log-likelihood.

    Arguments
    ---------
    gp : GP or GPwithClassifier
        Surrogate model. ``gp.predict_mean_single`` is used as the
        log-likelihood and ``gp.predict_var_single`` for the uncertainty on
        logZ. For a ``GPwithClassifier`` the initial live points are drawn
        from a larger uniform pool and re-drawn until their log-likelihoods
        are not all identical.
    mode : str
        If ``'acq'``, lower-precision settings are forced regardless of the
        values passed in: ``nlive = max(100, min(500, 20 * ndim))``,
        ``dlogz = 0.1``, ``maxcall = 2e6`` and ``equal_weights = True``.
        Any other value uses ``nlive = max(500, 40 * ndim)`` and the
        arguments as given.
    ndim : int
        Number of dimensions.
    dlogz : float
        Log evidence goal (overridden when ``mode == 'acq'``).
    dynamic : bool
        Currently ignored: a static nested sampler is always constructed.
        Kept for interface compatibility.
    maxcall : int
        Maximum number of function calls (overridden when ``mode == 'acq'``).
    print_progress : bool, optional
        Print progress of the nested sampling run. Always disabled in a
        cluster environment. If None, it is enabled outside a cluster
        environment.
    equal_weights : bool
        Resample to obtain equal weights (forced to True when
        ``mode == 'acq'``).
    sample_method : str
        Sampling method for dynesty.
    rng : numpy random generator, optional
        Used to draw the initial live points and passed to dynesty as
        ``rstate``. If None, one is obtained from ``get_numpy_rng()``.

    Returns
    -------
    samples_dict : dict
        Keys ``'x'`` (samples), ``'weights'``, ``'logl'``, ``'best'`` (the
        sample with the highest logl before any resampling) and
        ``'method'`` (``'nested'``).
    logz_dict : dict
        Keys ``'mean'`` and ``'dlogz_sampler'`` (logZ and its error from the
        sampler), ``'upper'`` and ``'lower'`` (logZ recomputed with
        logl +/- one GP standard deviation), ``'var'`` and ``'std'`` (the
        latter is twice the square root of ``'var'``).
    success : bool
        False when every returned sample has the same logl, True otherwise.
    """
    log.info("Running Nested Sampling using Dynesty...")

    # Progress printing is always off on clusters; `None` means "decide
    # automatically", which enables it everywhere else.
    in_cluster = is_cluster_environment()
    if print_progress is None:
        print_progress = not in_cluster
    elif in_cluster:
        print_progress = False

    @jax.jit
    def loglike(x):
        mu = gp.predict_mean_single(x)
        return mu

    start = time.time()

    if mode == 'acq':
        # a bit lower precision settings for acquisition
        nlive = max(100, min(500, 20 * ndim))
        dlogz = 0.1
        maxcall = int(2e6)
        equal_weights = True
    else:
        nlive = max(500, 40 * ndim)

    rng = rng if rng is not None else get_numpy_rng()

    if isinstance(gp, GPwithClassifier):
        # Draw a large pool once, then pick live-point subsets from it until
        # the subset does not have a completely flat log-likelihood.
        maxtries = 1000
        nlogl = 5000 * ndim
        x = rng.uniform(low=0., high=1., size=(nlogl, ndim))
        logl = jax.lax.map(loglike, x, batch_size=200)
        logl = np.array(logl)
        success = False
        for i in range(maxtries):
            live_indices = rng.choice(nlogl, size=nlive, replace=False)
            live_logl = logl[live_indices]
            if np.all(live_logl == live_logl[0]):
                log.debug(f" All logl values are the same on try {i+1}/{maxtries}. Retrying...")
            else:
                log.debug(f" Successful live points on try {i+1}/{maxtries}.")
                success = True
                break
        live_points = x[live_indices]
        live_logl = logl[live_indices]
        if not success:
            # Every try was flat: overwrite one live point with a point
            # obtained from the GP itself.
            valid_point = gp.get_random_point(rng=rng, nstd=1.0)
            valid_logl = float(loglike(valid_point))
            live_points[0] = valid_point
            live_logl[0] = valid_logl
    else:
        live_points = rng.uniform(low=0., high=1., size=(nlive, ndim))
        live_logl = jax.lax.map(loglike, live_points, batch_size=200)
        live_logl = np.array(live_logl)

    # The prior transform is the identity, so the same array serves as both
    # the unit-cube and the transformed live points.
    sampler = StaticNestedSampler(loglike, prior_transform, ndim=ndim, blob=False,
                                  live_points=[live_points, live_points, live_logl],
                                  sample=sample_method, nlive=nlive, rstate=rng)
    sampler.run_nested(print_progress=print_progress, dlogz=dlogz, maxcall=maxcall)

    res = sampler.results
    mean = res['logz'][-1]
    logz_err = res['logzerr'][-1]
    samples_x = res['samples']
    logl = res['logl']
    success = ~np.all(logl == logl[0])  # in case of failure do not check convergence
    log.debug(f" Nested Sampling took {time.time() - start:.2f}s")
    log.debug(" Log Z evaluated using {} points".format(np.shape(logl)))
    log.debug(f" Dynesty made {np.sum(res['ncall'])} function calls, max value of logl = {np.max(logl):.4f}")

    # Propagate the GP predictive uncertainty into bounds on logZ.
    var = jax.lax.map(gp.predict_var_single, samples_x, batch_size=100)
    std = np.sqrt(var)
    logl_lower, logl_upper = logl - std, logl + std
    logvol = res['logvol']
    upper = compute_integrals(logl=logl_upper, logvol=logvol)
    lower = compute_integrals(logl=logl_lower, logvol=logvol)

    var = np.clip(var, a_min=1e-12, a_max=1e12)
    varintegrand = 2*logl + np.log(var)
    log_var_delta = compute_integrals(logl=varintegrand, logvol=logvol, squared=True)[-1]
    log_var_logz = log_var_delta - 2*mean
    # Avoid numerical issues with very small or large variances
    log_var_logz = np.clip(log_var_logz, a_min=-100, a_max=100)
    var_logz = np.exp(log_var_logz)

    logz_dict = {'mean': mean, 'dlogz_sampler': logz_err,
                 'upper': upper[-1], 'lower': lower[-1],
                 'var': var_logz, 'std': 2*np.sqrt(var_logz)}

    # Best point is taken before resampling.
    best_pt = samples_x[np.argmax(logl)]
    weights = renormalise_log_weights(res['logwt'])
    if equal_weights:  # for MC points
        samples_x, logl = resample_equal(samples_x, logl, weights=weights)
        weights = np.ones(samples_x.shape[0])  # Equal weights after resampling

    samples_dict = {'x': samples_x, 'weights': weights, 'logl': logl,
                    'best': best_pt, 'method': 'nested'}

    return (samples_dict, logz_dict, success)
def sample_GP_NUTS(gp: Union[GP, GPwithClassifier], np_rng=None, rng_key=None,
                   num_chains=4, temp=1., **kwargs):
    """
    Obtain samples from the posterior represented by the GP mean as the logprob.

    This is a unified function that works for both GP and GPwithClassifier.
    Each chain is run as an independent single-chain NUTS run. Chains are run
    one after another when there is a single device or a single chain, and
    through ``jax.pmap`` (in chunks of the device count if needed) otherwise.

    Parameters
    ----------
    gp : Union[GP, GPwithClassifier]
        The Gaussian Process model to sample from.
    np_rng : np.random.Generator, optional
        NumPy random number generator, forwarded to ``gp.get_random_point``
        to draw the initial points. Default is None.
    rng_key : jax.random.PRNGKey, optional
        JAX random key. If None, one is obtained from ``get_new_jax_key()``.
    num_chains : int, optional
        Number of HMC chains (must be >= 1). Default is 4. When more than one
        chain is used, the last chain is initialised at the training point
        with the largest ``train_y``.
    temp : float, optional
        Temperature parameter for tempering; the GP mean is divided by it in
        the sampled log-density. Default is 1.0.
    **kwargs : dict
        Additional keyword arguments. Can include:

        - warmup_steps : int, optional
            Number of warmup steps for HMC. Falls back to 512.
        - num_samples : int, optional
            Number of samples to draw from each chain. Falls back to 1024.
        - thinning : int, optional
            Thinning factor for samples. Falls back to 4.
        - dense_mass : bool, optional
            Whether to use dense mass matrix in NUTS. Default is True.
        - max_tree_depth : int, optional
            Maximum tree depth for NUTS. Default is 6.

    Returns
    -------
    samples_dict : dict
        Dictionary containing:

        - 'x': samples from all chains, concatenated along the first axis
        - 'logp': untempered GP mean for each sample
        - 'best': sample with the largest 'logp'
        - 'method': 'MCMC'

    Raises
    ------
    ValueError
        If ``num_chains`` is smaller than 1.
    """
    if num_chains < 1:
        raise ValueError(f"num_chains must be >= 1, got {num_chains}")

    # Extract HMC settings from kwargs with simple fallback defaults
    # Note: Dimension-based defaults are now handled centrally in bo.py
    warmup_steps = kwargs.get('warmup_steps', 512)
    num_samples = kwargs.get('num_samples', 1024)
    thinning = kwargs.get('thinning', 4)
    dense_mass = kwargs.get('dense_mass', True)
    max_tree_depth = kwargs.get('max_tree_depth', 6)

    shape = gp.train_x.shape[1]

    def model():
        x = numpyro.sample('x', dist.Uniform(
            low=jnp.zeros(shape),
            high=jnp.ones(shape)
        ))
        mean = gp.predict_mean_batched(x)
        # Tempered density drives the sampler; the untempered value is
        # recorded alongside the samples.
        numpyro.factor('y', mean/temp)
        numpyro.deterministic('logp', mean)

    @jax.jit
    def run_single_chain(chain_key, init_x):
        init_strategy = init_to_value(values={'x': init_x})
        kernel = NUTS(model,
                      dense_mass=dense_mass,
                      max_tree_depth=max_tree_depth,
                      init_strategy=init_strategy)
        mcmc = MCMC(kernel,
                    num_warmup=warmup_steps,
                    num_samples=num_samples,
                    num_chains=1,
                    progress_bar=False,
                    thinning=thinning)
        mcmc.run(chain_key)
        draws = mcmc.get_samples()
        return draws['x'], draws['logp']

    num_devices = jax.device_count()

    rng_key = rng_key if rng_key is not None else get_new_jax_key()
    rng_keys = jax.random.split(rng_key, num_chains)

    # Generate initialization points
    if num_chains == 1:
        inits = jnp.array([gp.get_random_point(rng=np_rng)])
    else:
        inits = jnp.vstack([gp.get_random_point(rng=np_rng) for _ in range(num_chains-1)])
        # Add best training point as one init
        inits = jnp.vstack([inits, gp.train_x[jnp.argmax(gp.train_y)]])

    log.debug(f"Running MCMC with {num_chains} chains on {num_devices} devices.")

    # Adaptive method selection based on device/chain configuration.
    # A single chain is always run sequentially, including on multi-device
    # hosts; that combination previously matched no branch.
    if num_devices == 1 or num_chains == 1:
        if num_devices == 1:
            log.debug("Using sequential method (single device)")
        else:
            log.debug("Using sequential method (single chain)")
        samples_x = []
        logps = []
        for chain_key, init_x in zip(rng_keys, inits):
            samples_x_i, logps_i = run_single_chain(chain_key, init_x)
            samples_x.append(samples_x_i)
            logps.append(logps_i)
        samples_x = jnp.concatenate(samples_x)
        logps = jnp.concatenate(logps)

    elif num_devices >= num_chains:
        # Direct pmap method when devices >= chains
        log.debug("Using direct pmap method (devices >= chains)")
        pmapped = jax.pmap(run_single_chain, in_axes=(0, 0), out_axes=(0, 0))
        samples_x, logps = pmapped(rng_keys, inits)
        # Merge the leading chain axis into the sample axis.
        samples_x = jnp.concatenate(samples_x, axis=0)
        logps = jnp.concatenate(logps, axis=0)
        logps = jnp.reshape(logps, (samples_x.shape[0],))

    else:
        # Chunked method when 1 < devices < chains
        log.debug(f"Using chunked pmap method ({num_devices} devices < {num_chains} chains)")

        # Process chains in chunks of device count using the existing run_single_chain
        pmapped_chunked = jax.pmap(run_single_chain, in_axes=(0, 0), out_axes=(0, 0))

        all_samples = []
        all_logps = []
        for i in range(0, num_chains, num_devices):
            end_idx = min(i + num_devices, num_chains)
            chunk_keys = rng_keys[i:end_idx]
            chunk_inits = inits[i:end_idx]

            # The last chunk may hold fewer chains than there are devices.
            chunk_samples, chunk_logps = pmapped_chunked(chunk_keys, chunk_inits)
            all_samples.append(chunk_samples)
            all_logps.append(chunk_logps)

        # Concatenate all chunks
        samples_x = jnp.concatenate([jnp.concatenate(chunk, axis=0) for chunk in all_samples], axis=0)
        logps = jnp.concatenate([jnp.concatenate(chunk, axis=0) for chunk in all_logps], axis=0)
        # Same flattening as the direct pmap path, so both return one logp per sample.
        logps = jnp.reshape(logps, (samples_x.shape[0],))

    samples_dict = {
        'x': samples_x,
        'logp': logps,
        'best': samples_x[jnp.argmax(logps)],
        'method': "MCMC"
    }

    log.debug(f"Max logl found in HMC = {np.max(logps):.4f}")

    return samples_dict