# Source code for BOBE.clf_gp

import time
import numpy as np
import jax
# Switch JAX to 64-bit floats; this is done before jax.numpy is imported so
# the setting is in place before any array is created by this module.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from .gp import GP, safe_noise_floor
# CLASSIFIER_REGISTRY maps a classifier name to a dict holding (at least) the
# 'train_fn' and 'predict_fn' entries looked up by GPwithClassifier.
from .clf import (
    CLASSIFIER_REGISTRY
)
from .utils.seed import get_new_jax_key, get_numpy_rng
from .utils.log import get_logger
from .utils.core import get_threshold_for_nsigma
# Module-level logger shared by everything in this file.
log = get_logger("clf_gp")


class GPwithClassifier(GP):
    """
    Gaussian process combined with a classifier that gates its predictions.

    Every evaluated point is kept in ``train_x_clf`` / ``train_y_clf`` and is
    used to train the classifier. Only the points whose value lies within
    ``gp_threshold`` of the current maximum are handed to the underlying GP.
    Once the classifier is active, predictions at points it rejects are
    replaced by ``minus_inf`` (mean) and ``safe_noise_floor`` (variance).
    """

    # Classifier types for which from_state_dict knows how to rebuild the
    # prediction function from stored parameters.
    _REBUILDABLE_CLF_TYPES = ('svm', 'nn', 'ellipsoid')

    def __init__(
        self,
        train_x=None,
        train_y=None,
        clf_type='svm',
        clf_settings=None,
        clf_use_size=10,
        clf_update_step=1,
        probability_threshold=0.5,
        minus_inf=-1e5,
        clf_threshold=250.,
        gp_threshold=500.,
        noise=1e-8,
        kernel="rbf",
        optimizer="scipy",
        optimizer_options=None,
        kernel_variance_bounds=None,
        lengthscale_bounds=None,
        tausq=None,
        tausq_bounds=None,
        kernel_variance_prior=None,
        lengthscale_prior=None,
        lengthscales=None,
        kernel_variance=1.0,
        param_names=None,
        train_clf_on_init=True,  # Prevent retraining on copy
    ):
        """
        Generic Classifier-GP class combining a GP with a classifier.

        The GP is trained on the data points that are within the GP threshold
        of the maximum training value.

        Arguments
        ---------
        train_x : array-like, shape (n_samples, n_dim)
            Initial training points. The arrays are built directly from this
            value, so actual data must be supplied despite the ``None`` default.
        train_y : array-like, shape (n_samples,)
            Initial training values (stored internally as a column).
        clf_type : str, optional
            Key into ``CLASSIFIER_REGISTRY`` ('svm', 'nn', 'ellipsoid', etc.),
            compared case-insensitively. Default is 'svm'.
        clf_settings : dict, optional
            Settings passed to the classifier training function.
            ``None`` (default) means an empty dict.
        clf_use_size : int, optional
            Minimum number of points to start using the classifier.
            Default is 10.
        clf_update_step : int, optional
            Intended classifier update interval; it is only stored and
            serialised by this class. Default is 1.
        probability_threshold : float, optional
            Classifier output at or above which a point is accepted.
            Default is 0.5.
        minus_inf : float, optional
            Mean value returned for rejected points. Default is -1e5.
        clf_threshold : float, optional
            Points more than this far below the maximum are labelled 0 for
            classifier training. Default is 250.
        gp_threshold : float, optional
            Points within this distance of the maximum are used to train the
            GP. Default is 500.
        noise, kernel, optimizer, optimizer_options, kernel_variance_bounds,
        lengthscale_bounds, tausq, tausq_bounds, kernel_variance_prior,
        lengthscale_prior, lengthscales, kernel_variance, param_names :
            Forwarded to the base ``GP`` constructor. ``None`` for
            ``optimizer_options`` means ``{}``; ``None`` for the bounds means
            ``[1e-4, 1e8]`` (kernel variance), ``[0.01, 5.]`` (lengthscale) and
            ``[1e-4, 1e4]`` (tausq). ``lengthscale_prior=None`` is forwarded
            as "DSLP". Note: bounds are in actual space, not log10.
        train_clf_on_init : bool, optional
            If True (default) and enough data is available, train the
            classifier during construction.

        Raises
        ------
        ValueError
            If ``clf_type`` is not present in ``CLASSIFIER_REGISTRY``.
        """
        # Fresh containers per instance instead of shared mutable defaults.
        if clf_settings is None:
            clf_settings = {}
        if optimizer_options is None:
            optimizer_options = {}
        if kernel_variance_bounds is None:
            kernel_variance_bounds = [1e-4, 1e8]
        if lengthscale_bounds is None:
            lengthscale_bounds = [0.01, 5.]
        if tausq_bounds is None:
            tausq_bounds = [1e-4, 1e4]

        # Store Data and Classifier Settings
        self.train_x_clf = jnp.array(train_x)
        self.train_y_clf = jnp.array(train_y).reshape(-1, 1)  # Ensure 2D
        self.clf_use_size = clf_use_size
        self.clf_update_step = clf_update_step
        self.clf_type = clf_type.lower()
        self.clf_settings = clf_settings
        self.clf_params = None
        self.clf_metrics = {}
        self.probability_threshold = probability_threshold
        self.minus_inf = minus_inf

        # Store classifier functions and settings
        if self.clf_type not in CLASSIFIER_REGISTRY:
            raise ValueError(f"Unsupported classifier type: {self.clf_type}")
        registry_entry = CLASSIFIER_REGISTRY[self.clf_type]
        self.clf_train_fn = registry_entry['train_fn']
        self.clf_predict_fn = registry_entry['predict_fn']

        # Handle Thresholds
        self.clf_threshold = clf_threshold
        self.gp_threshold = gp_threshold

        # Prepare GP Data: only points close enough to the maximum.
        if self.train_y_clf.size > 0:
            train_x_gp, train_y_gp = self._gp_training_subset()
        else:
            train_x_gp = self.train_x_clf
            train_y_gp = self.train_y_clf

        # Initialize GP using inheritance
        super().__init__(
            train_x=train_x_gp,
            train_y=train_y_gp,
            noise=noise,
            kernel=kernel,
            optimizer=optimizer,
            optimizer_options=optimizer_options,
            kernel_variance_bounds=kernel_variance_bounds,
            lengthscale_bounds=lengthscale_bounds,
            lengthscales=lengthscales,
            kernel_variance=kernel_variance,
            lengthscale_prior=lengthscale_prior if lengthscale_prior is not None else "DSLP",
            kernel_variance_prior=kernel_variance_prior,
            tausq=tausq,
            tausq_bounds=tausq_bounds,
            param_names=param_names,
        )

        # Initialize Classifier
        self.use_clf = self.clf_data_size >= self.clf_use_size
        self.clf_model_params = None
        self._clf_predict_func = None
        if self.use_clf:
            if train_clf_on_init:
                self.train_classifier()
        else:
            log.debug(f"Not enough data ({self.clf_data_size}) to use classifier (need {self.clf_use_size} points), or classifier type not set.")

    def _gp_training_subset(self):
        """Return the (x, y) rows whose value is within gp_threshold of the maximum."""
        mask_gp = self.train_y_clf.flatten() > (self.train_y_clf.max() - self.gp_threshold)
        return self.train_x_clf[mask_gp], self.train_y_clf[mask_gp]

    def _clf_is_active(self):
        """Return True when classifier gating should be applied to predictions."""
        return self.use_clf and self._clf_predict_func is not None

    def train_classifier(self):
        """
        Public method to train/retrain the classifier.

        Switches the classifier on once ``clf_data_size`` has reached
        ``clf_use_size`` and then trains it; does nothing while there is
        still too little data.
        """
        # Check if classifier data size has reached the threshold
        if not self.use_clf and self.clf_data_size >= self.clf_use_size:
            log.info(f"Classifier data size ({self.clf_data_size}) reached use size ({self.clf_use_size}). Will start using classifier.")
            self.use_clf = True

        if self.use_clf:
            self._train_classifier()

    def _train_classifier(self):
        """
        Trains the classifier based on clf_type.

        Points more than ``clf_threshold`` below the maximum value get label
        0, all others label 1. If only one class is present the classifier is
        switched off (``use_clf = False``) and nothing is trained.
        """
        start_time = time.time()

        # Determine labels for classifier training
        labels = np.where(
            self.train_y_clf.flatten() < self.train_y_clf.max() - self.clf_threshold,
            0,
            1,
        )
        log.debug(f" Number of labels 0: {np.sum(labels == 0)}, 1: {np.sum(labels == 1)}")

        # A classifier cannot be trained on a single class.
        if np.all(labels == labels[0]):
            log.debug("All labels are identical. Not using classifier for the moment")
            self.use_clf = False
            return

        # The best point found so far is passed on to the training function.
        best_pt = self.train_x_clf[jnp.argmax(self.train_y_clf)]

        # The registered training function returns params, metrics and the
        # prediction function; previous params are supplied as init_params.
        self.clf_params, self.clf_metrics, self._clf_predict_func = self.clf_train_fn(
            self.train_x_clf,
            labels,
            self.clf_settings,
            init_params=self.clf_params,
            best_pt=best_pt,
        )

        log.debug(f"Trained {self.clf_type.upper()} classifier on {self.clf_data_size} points in {time.time() - start_time:.2f}s")
        log.debug(f"Classifier metrics: {self.clf_metrics}")  # Use debug for detailed metrics

    def predict_mean_single(self, x):
        """
        Predict the mean at a single point.

        Returns the base GP mean, replaced by ``minus_inf`` where the active
        classifier's output is below ``probability_threshold``.
        """
        gp_mean = super().predict_mean_single(x)
        if not self._clf_is_active():
            return gp_mean
        clf_probs = self._clf_predict_func(x)
        return jnp.where(clf_probs >= self.probability_threshold, gp_mean, self.minus_inf)

    def predict_var_single(self, x):
        """
        Predict the variance at a single point.

        Returns the base GP variance, replaced by ``safe_noise_floor`` where
        the active classifier's output is below ``probability_threshold``.
        """
        var = super().predict_var_single(x)
        if not self._clf_is_active():
            return var
        clf_probs = self._clf_predict_func(x)
        return jnp.where(clf_probs >= self.probability_threshold, var, safe_noise_floor)

    def predict_mean_batched(self, x):
        """Apply predict_mean_single to every row of ``x`` via ``jax.vmap``."""
        x = jnp.atleast_2d(x)
        return jax.vmap(self.predict_mean_single)(x)

    def predict_var_batched(self, x):
        """Apply predict_var_single to every row of ``x`` via ``jax.vmap``."""
        x = jnp.atleast_2d(x)
        return jax.vmap(self.predict_var_single)(x)

    def predict_single(self, x):
        """
        Predict mean and variance at a single point.

        Returns the base GP ``(mean, var)``; where the active classifier's
        output is below ``probability_threshold`` they are replaced by
        ``minus_inf`` and ``safe_noise_floor`` respectively.
        """
        mean, var = super().predict_single(x)
        if not self._clf_is_active():
            return mean, var
        accepted = self._clf_predict_func(x) >= self.probability_threshold
        mean = jnp.where(accepted, mean, self.minus_inf)
        var = jnp.where(accepted, var, safe_noise_floor)
        return mean, var

    def fantasy_var(self, new_x, mc_points, k_train_mc):
        """
        Computes the fantasy variance, see gp.py for more details.

        Currently a plain pass-through to the base class; classifier logic
        could potentially be added here if needed.
        """
        return super().fantasy_var(new_x, mc_points, k_train_mc)

    def update(self, new_x, new_y):
        """
        Add new points to the classifier data and refresh the GP training set.

        Points that are (numerically) identical to an existing training point,
        or to an earlier point of the same batch, are skipped. The GP subset
        is then re-selected with ``gp_threshold``, re-standardised, and the
        Cholesky factor is recomputed. The classifier itself is NOT retrained
        here; call ``train_classifier`` for that.

        Arguments
        ---------
        new_x : array-like, shape (n_new, n_dim) or (n_dim,)
            New input point(s).
        new_y : array-like with n_new values
            Values at ``new_x``; any shape holding one value per point.

        Raises
        ------
        ValueError
            If the number of values does not match the number of points.
        """
        new_x = jnp.atleast_2d(new_x)
        # reshape (rather than atleast_2d) so a 1-D batch of values becomes one
        # row per point instead of a single row holding all values.
        new_y = jnp.asarray(new_y).reshape(-1, 1)
        if new_y.shape[0] != new_x.shape[0]:
            raise ValueError(
                f"Got {new_x.shape[0]} new points but {new_y.shape[0]} new values"
            )

        # Grow the data point by point so each candidate is also compared with
        # the points already accepted from this batch.
        all_x = self.train_x_clf
        all_y = self.train_y_clf
        for pt, val in zip(new_x, new_y):
            is_duplicate = jnp.any(
                jnp.all(jnp.isclose(all_x, pt, atol=1e-6, rtol=1e-4), axis=1)
            )
            if is_duplicate:
                log.debug(f"Point {pt} already exists in the training set, not updating")
                continue
            all_x = jnp.concatenate([all_x, pt[None, :]], axis=0)
            all_y = jnp.concatenate([all_y, val.reshape(1, 1)], axis=0)
        self.train_x_clf = all_x
        self.train_y_clf = all_y

        # Re-select and re-standardise the GP training data.
        gp_x, gp_y = self._gp_training_subset()
        self.train_x = gp_x
        self.train_y = gp_y.reshape(-1, 1)
        self.y_std = jnp.std(self.train_y) if self.train_y.shape[0] > 1 else 1.0
        self.y_mean = jnp.mean(self.train_y)
        self.train_y = (self.train_y - self.y_mean) / self.y_std
        self.recompute_cholesky()

        log.debug(f"Classifier data size: {self.train_y_clf.shape[0]}, GP data size: {self.train_y.shape[0]}")

    def kernel(self, x1, x2, lengthscales, kernel_variance, noise, include_noise=True):
        """
        Returns the kernel function used by the GP (pass-through to the base class).
        """
        return super().kernel(x1, x2, lengthscales, kernel_variance, noise, include_noise=include_noise)

    def get_random_point(self, rng=None, nstd=None):
        """
        Return a random starting point.

        With the classifier in use, a stored training point is drawn uniformly
        from those whose value is within a threshold of the maximum: the
        threshold is ``get_threshold_for_nsigma(nstd, self.ndim)`` when
        ``nstd`` is given, otherwise ``clf_threshold``. Without the classifier
        the base class implementation is used.

        Arguments
        ---------
        rng : numpy random generator, optional
            Defaults to ``get_numpy_rng()``.
        nstd : float, optional
            Number of sigma used to derive the selection threshold.
        """
        rng = rng if rng is not None else get_numpy_rng()
        if not self.use_clf:
            return super().get_random_point(rng=rng, nstd=nstd)

        if nstd is not None:
            threshold = get_threshold_for_nsigma(nstd, self.ndim)
        else:
            threshold = self.clf_threshold
        pts_idx = self.train_y_clf.flatten() > self.train_y_clf.max() - threshold

        # Sample a random point from the filtered points
        valid_indices = jnp.where(pts_idx)[0]
        chosen_index = rng.choice(valid_indices, size=1)[0]
        pt = self.train_x_clf[chosen_index]
        log.debug(f"Random point sampled with value {self.train_y_clf[chosen_index]}")
        return pt

    def state_dict(self):
        """
        Returns a dictionary containing the complete state of the GPwithClassifier.

        This can be used for saving, loading, or copying the GPwithClassifier.

        Returns
        -------
        state: dict
            The base GP state extended with the classifier data,
            configuration and trained parameters.
        """
        # Start with the base GP state
        state = super().state_dict()

        # Add classifier-specific data
        state.update({
            # Classifier training data
            'train_x_clf': np.array(self.train_x_clf),
            'train_y_clf': np.array(self.train_y_clf),
            # Classifier configuration
            'clf_type': self.clf_type,
            'clf_settings': self.clf_settings,
            'clf_use_size': self.clf_use_size,
            'clf_update_step': self.clf_update_step,
            'probability_threshold': self.probability_threshold,
            'minus_inf': self.minus_inf,
            'clf_threshold': self.clf_threshold,
            'gp_threshold': self.gp_threshold,
            'use_clf': self.use_clf,
            # Classifier state
            'clf_params': self.clf_params,
            'clf_metrics': self.clf_metrics,
            # Class identifier
            'gp_class': 'GPwithClassifier',
        })
        return state

    @classmethod
    def from_state_dict(cls, state):
        """
        Creates a GPwithClassifier instance from a state dictionary.

        Arguments
        ---------
        state: dict
            State dictionary returned by state_dict(). An optional
            'train_clf_on_init' entry overrides whether the classifier is
            trained during construction.

        Returns
        -------
        gp_clf: GPwithClassifier
            The reconstructed GPwithClassifier object
        """
        # When the stored parameters are enough to rebuild the prediction
        # function below, training in the constructor would be thrown away
        # immediately, so it is skipped unless explicitly requested.
        can_rebuild = (
            state.get('clf_params') is not None
            and str(state['clf_type']).lower() in cls._REBUILDABLE_CLF_TYPES
        )

        # NOTE(review): 'param_names' is not forwarded here; confirm whether
        # the base state_dict stores it and whether it should be restored.
        gp_clf = cls(
            train_x=state['train_x_clf'],
            train_y=state['train_y_clf'],
            clf_type=state['clf_type'],
            clf_settings=state['clf_settings'],
            clf_use_size=state['clf_use_size'],
            clf_update_step=state['clf_update_step'],
            probability_threshold=state['probability_threshold'],
            minus_inf=state['minus_inf'],
            clf_threshold=state['clf_threshold'],
            gp_threshold=state['gp_threshold'],
            noise=state['noise'],
            kernel=state['kernel_name'],
            optimizer=state['optimizer_method'],
            optimizer_options=state['optimizer_options'],
            kernel_variance_bounds=state['kernel_variance_bounds'],
            lengthscale_bounds=state['lengthscale_bounds'],
            lengthscales=state['lengthscales'],
            kernel_variance=state['kernel_variance'],
            kernel_variance_prior=state.get('kernel_variance_prior_spec'),
            lengthscale_prior=state.get('lengthscale_prior_spec'),
            tausq=state.get('tausq', 1.0),
            # Fallback matches the constructor default (actual space, not log10).
            tausq_bounds=state.get('tausq_bounds', [1e-4, 1e4]),
            train_clf_on_init=state.get('train_clf_on_init', not can_rebuild),
        )

        # Cached 'cholesky' / 'alphas' entries of the state are deliberately
        # not restored here.

        # Restore classifier state
        gp_clf.use_clf = state['use_clf']
        gp_clf.clf_params = state.get('clf_params')
        gp_clf.clf_metrics = state.get('clf_metrics', {})

        # Regenerate prediction function if classifier parameters exist
        if gp_clf.clf_params is not None:
            if gp_clf.clf_type == 'svm':
                gp_clf._clf_predict_func = gp_clf.clf_predict_fn(gp_clf.clf_params)
            elif gp_clf.clf_type == 'nn':
                gp_clf._clf_predict_func = gp_clf.clf_predict_fn(
                    gp_clf.clf_params, gp_clf.clf_settings
                )
            elif gp_clf.clf_type == 'ellipsoid':
                d = gp_clf.train_x_clf.shape[1]
                gp_clf._clf_predict_func = gp_clf.clf_predict_fn(
                    gp_clf.clf_params, gp_clf.clf_settings, d
                )

        return gp_clf

    def save(self, filename='gp'):
        """
        Save the GPwithClassifier state to a file using state_dict.

        Arguments
        ---------
        filename: str
            The filename to save to (with or without .npz extension). Default is 'gp'.
        """
        if not filename.endswith('.npz'):
            filename += '.npz'

        state = self.state_dict()
        np.savez(filename, **state)
        log.info(f"Saved GPwithClassifier state to {filename}")

    @classmethod
    def load(cls, filename, **kwargs):
        """
        Loads a GPwithClassifier from a file

        Arguments
        ---------
        filename: str
            The name of the file to load the GPwithClassifier from (with or without .npz extension)
        **kwargs:
            Entries that override (or add to) the loaded state before the
            object is rebuilt with from_state_dict.

        Returns
        -------
        gp_clf: GPwithClassifier
            The loaded GPwithClassifier object

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        """
        if not filename.endswith('.npz'):
            filename += '.npz'

        # SECURITY: allow_pickle=True unpickles stored objects, which can run
        # arbitrary code. Only load files from trusted sources.
        try:
            archive = np.load(filename, allow_pickle=True)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find file {filename}") from err

        # Convert arrays back to the expected format; the context manager
        # closes the underlying file once everything has been read.
        state = {}
        with archive as data:
            for key in data.files:
                value = data[key]
                if isinstance(value, np.ndarray) and value.shape == ():
                    # Handle scalar arrays
                    state[key] = value.item()
                else:
                    state[key] = value

        # Apply any override kwargs
        state.update(kwargs)

        # Use from_state_dict for loading
        gp_clf = cls.from_state_dict(state)

        log.info(f"Loaded GPwithClassifier from {filename} with {gp_clf.train_x.shape[0]} training points")
        return gp_clf

    def copy(self):
        """
        Creates a copy of the GPwithClassifier by round-tripping through state_dict.

        Returns
        -------
        gp_clf_copy: GPwithClassifier
            A new object rebuilt from the current state
        """
        return self.__class__.from_state_dict(self.state_dict())

    @property
    def clf_data_size(self):
        """Size of the classifier's training inputs."""
        return self.train_x_clf.shape[0]

    @property
    def npoints(self):
        """Total number of stored points (classifier training set size)."""
        return self.train_x_clf.shape[0]