# Source code for BOBE.clf

# interfaces and routines for some classifiers
# SVM, Neural Networks, Ellipsoidal bound, etc.

import numpy as np
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from sklearn.svm import SVC
from typing import Callable, Dict, Any, Union, List, Optional, Tuple
from functools import partial
from .utils.log import get_logger
from .utils.seed import get_numpy_rng, get_new_jax_key
log = get_logger("clf")

try:
    import optax
    OPTAX_AVAILABLE = True
except ImportError:
    OPTAX_AVAILABLE = False
    optax = None
    log.debug("optax is not available. NN and Ellipsoid classifiers will require it.")

try:
    from flax import linen as nn
    FLAX_AVAILABLE = True
except ImportError:
    FLAX_AVAILABLE = False
    nn = None
    log.debug("Flax is not available. Only SVM classifier will be available.")


# -----------------------------------------------------------------------------
# Standalone training and prediction functions for classifiers
# -----------------------------------------------------------------------------

def train_svm_classifier(X, Y, settings=None, init_params=None, **kwargs):
    """Train an SVM classifier and export it as a JAX prediction function.

    The model is fitted with scikit-learn's ``SVC``; its support vectors,
    dual coefficients, intercept and effective gamma are then copied out so
    that prediction can run through ``svm_predict_proba`` under ``jax.jit``.

    Args:
        X: Training features, passed unchanged to ``SVC.fit``.
        Y: Training labels, passed unchanged to ``SVC.fit``.
        settings: Optional dict with the keys ``'gamma'`` (default
            ``'scale'``), ``'C'`` (default ``1e7``) and ``'kernel'``
            (default ``'rbf'``). ``None`` means "use all defaults".
        init_params: Accepted for a uniform trainer signature; unused here.
        **kwargs: Accepted for a uniform trainer signature; unused here.

    Returns:
        Tuple ``(params, metrics, predict_fn)`` where ``params`` holds
        ``'support_vectors'``, ``'dual_coef'``, ``'intercept'`` and
        ``'gamma_eff'``, ``metrics`` holds display-formatted summary values,
        and ``predict_fn`` is the jitted prediction function.

    Raises:
        ValueError: If ``settings['kernel']`` is not ``'rbf'``. The exported
            predictor evaluates the RBF kernel formula only, so any other
            kernel would train fine but predict incorrectly.
    """
    # A None default avoids sharing one dict object across calls.
    settings = {} if settings is None else settings

    gamma = settings.get('gamma', 'scale')
    C = settings.get('C', 1e7)
    kernel = settings.get('kernel', 'rbf')
    if kernel != 'rbf':
        raise ValueError(
            f"Unsupported SVM kernel {kernel!r}: the JAX prediction function "
            "only implements the 'rbf' kernel."
        )

    rng = get_numpy_rng()
    random_state = int(rng.integers(0, 2**31 - 1))

    clf = SVC(kernel=kernel, gamma=gamma, C=C, random_state=random_state)
    clf.fit(X, Y)

    # Row 0 of dual_coef_ is taken as a 1D coefficient vector.
    support_vectors = jnp.array(clf.support_vectors_)
    dual_coef = jnp.array(clf.dual_coef_[0])
    intercept = float(clf.intercept_[0])
    # note: this is the effective gamma value used by scikit-learn
    gamma_eff = float(clf._gamma)

    metrics = {
        'n_support_vectors': len(support_vectors),
        'gamma': f"{gamma_eff:.2e}",
        'C': f"{C:.2e}",
        'intercept': f"{intercept:.2e}",
    }
    params = {
        'support_vectors': support_vectors,
        'dual_coef': dual_coef,
        'intercept': intercept,
        'gamma_eff': gamma_eff,
    }

    # Bind the fitted quantities so the returned callable only takes x.
    predict_fn = jax.jit(partial(
        svm_predict_proba,
        support_vectors=support_vectors,
        dual_coef=dual_coef,
        intercept=intercept,
        gamma=gamma_eff,
    ))
    return params, metrics, predict_fn
def get_svm_predict_proba_fn(params):
    """Rebuild the jitted SVM prediction function from stored parameters.

    Intended for classifiers restored from file: ``params`` must provide the
    keys ``'support_vectors'``, ``'dual_coef'``, ``'intercept'`` and
    ``'gamma_eff'``.

    Returns:
        A ``jax.jit``-wrapped callable taking only the input point ``x``.
    """
    bound_predict = partial(
        svm_predict_proba,
        support_vectors=params['support_vectors'],
        dual_coef=params['dual_coef'],
        intercept=params['intercept'],
        gamma=params['gamma_eff'],
    )
    return jax.jit(bound_predict)
# ----------------------------------------------------------------------------- # Neural Network Classifier (currently in development)
def train_nn_classifier(X, Y, settings=None, init_params=None, **kwargs):
    """Train the MLP classifier and return parameters, metrics and a predictor.

    Args:
        X: Training features; ``X.shape[0]`` is used as the sample count.
        Y: Training labels.
        settings: Optional dict of ``MLPClassifier`` fields. ``'hidden_dims'``
            and ``'batch_size'`` are always overwritten (see below). ``None``
            means "start from an empty settings dict".
        init_params: Optional initial parameters, forwarded to the first
            training restart.
        **kwargs: Accepted for a uniform trainer signature; unused here.

    Returns:
        Tuple ``(params, metrics, predict_fn)``; ``predict_fn`` is a jitted
        function mapping inputs to sigmoid probabilities.

    Raises:
        ImportError: If Flax or optax is not installed.
    """
    if not FLAX_AVAILABLE or not OPTAX_AVAILABLE:
        raise ImportError(
            "Flax and optax are required for NN classifier. "
            "Install with: pip install 'BOBE[nn]'"
        )

    # A None default replaces the former `settings={}`, which was one shared
    # dict object mutated by every call that omitted the argument.
    if settings is None:
        settings = {}

    # Create model with settings. The architecture is fixed; only the batch
    # size depends on how many labelled points are available.
    # NOTE(review): a caller-supplied dict is still updated in place, exactly
    # as before; presumably callers reuse it to rebuild the model later --
    # confirm before switching to a copy.
    label_size = X.shape[0]
    settings.update({
        'hidden_dims': [32, 32],
        'batch_size': 64 if label_size < 500 else 128,
    })
    model = MLPClassifier(**settings)

    # Train with multiple restarts
    params, metrics = train_nn_multiple_restarts(
        model=model,
        x=X,
        y=Y,
        init_params=init_params,
    )

    # Create predict function (dropout disabled via train=False)
    def predict_proba_fn(x):
        logits = model.apply(params, x, train=False)
        return jax.nn.sigmoid(logits.squeeze(-1))

    predict_fn = jax.jit(predict_proba_fn)
    return params, metrics, predict_fn
def get_nn_predict_proba_fn(params, settings=None, **kwargs):
    """Rebuild the NN prediction function from stored parameters.

    Intended for classifiers restored from file. The model is recreated from
    ``settings`` only to obtain its ``apply`` function, so ``settings`` must
    describe the same architecture that produced ``params``.

    Args:
        params: Trained Flax parameters.
        settings: Optional dict of ``MLPClassifier`` fields; ``None`` means
            "use the class defaults".
        **kwargs: Accepted for a uniform signature; unused here.

    Returns:
        A jitted function mapping inputs to sigmoid probabilities.

    Raises:
        ImportError: If Flax is not installed. Without this guard the failure
            would be an opaque "'NoneType' object is not callable".
    """
    if not FLAX_AVAILABLE:
        raise ImportError(
            "Flax is required for NN classifier. "
            "Install with: pip install 'BOBE[nn]'"
        )

    # A None default avoids sharing one dict object across calls.
    settings = {} if settings is None else settings

    # Recreate model with same settings to get the apply function
    model = MLPClassifier(**settings)

    def predict_proba_fn(x):
        logits = model.apply(params, x, train=False)
        return jax.nn.sigmoid(logits.squeeze(-1))

    return jax.jit(predict_proba_fn)
def train_ellipsoid_classifier(X, Y, settings=None, init_params=None, **kwargs):
    """Train the ellipsoid classifier and return parameters, metrics, predictor.

    Args:
        X: Training features; ``X.shape[1]`` is used as the dimension ``d``.
        Y: Training labels.
        settings: Optional dict of extra ``EllipsoidClassifier`` fields;
            ``None`` means "use the class defaults".
        init_params: Optional initial parameters, forwarded to the first
            training restart.
        **kwargs: May contain ``'best_pt'``, the ellipsoid centre. When it is
            absent or ``None`` the centre defaults to ``0.5`` in every
            dimension.

    Returns:
        Tuple ``(params, metrics, predict_fn)``; ``predict_fn`` is a jitted
        function mapping inputs to sigmoid probabilities.

    Raises:
        ImportError: If Flax or optax is not installed.
    """
    if not FLAX_AVAILABLE or not OPTAX_AVAILABLE:
        raise ImportError(
            "Flax and optax are required for Ellipsoid classifier. "
            "Install with: pip install 'BOBE[nn]'"
        )

    # A None default avoids sharing one dict object across calls.
    settings = {} if settings is None else settings

    d = X.shape[1]
    # Treat an explicit best_pt=None like a missing one instead of passing
    # None through as the centre.
    best_pt = kwargs.get('best_pt')
    mu = 0.5 * jnp.ones(d) if best_pt is None else best_pt

    # Create model with settings
    model = EllipsoidClassifier(d=d, mu=mu, **settings)

    # Train with multiple restarts
    params, metrics = train_ellipsoid_multiple_restarts(
        model=model,
        x=X,
        y=Y,
        init_params=init_params,
    )

    def predict_proba_fn(x):
        logits = model.apply(params, x, train=False)
        return jax.nn.sigmoid(logits.squeeze())

    predict_fn = jax.jit(predict_proba_fn)
    return params, metrics, predict_fn
def get_ellipsoid_predict_proba_fn(params, settings, d, **kwargs):
    """Rebuild the ellipsoid prediction function from stored parameters.

    Intended for classifiers restored from file. The centre is read from
    ``kwargs['best_pt']`` when that key is present; otherwise it is ``0.5``
    in each of the ``d`` dimensions.

    Returns:
        A jitted function mapping inputs to sigmoid probabilities.
    """
    default_center = 0.5 * jnp.ones(d)
    center = kwargs.get('best_pt', default_center)
    classifier = EllipsoidClassifier(d=d, mu=center, **settings)

    @jax.jit
    def predict_proba_fn(x):
        raw_scores = classifier.apply(params, x, train=False)
        return jax.nn.sigmoid(raw_scores.squeeze())

    return predict_proba_fn
# Registry of available classifiers: each entry pairs the function that trains
# a classifier with the function that rebuilds its predictor from saved params.
CLASSIFIER_REGISTRY = dict(
    svm=dict(
        train_fn=train_svm_classifier,
        predict_fn=get_svm_predict_proba_fn,
    ),
    nn=dict(
        train_fn=train_nn_classifier,
        predict_fn=get_nn_predict_proba_fn,
    ),
    ellipsoid=dict(
        train_fn=train_ellipsoid_classifier,
        predict_fn=get_ellipsoid_predict_proba_fn,
    ),
)

# -----------------------------------------------------------------------------
# SVM prediction functions
# -----------------------------------------------------------------------------
def svm_predict(x: jnp.ndarray, support_vectors: jnp.ndarray, dual_coef: jnp.ndarray, intercept: float, gamma: float):
    """
    Evaluate the RBF-kernel SVM decision function at a single point.

    Arguments:
        x: Input data point, shape (n_features,)
        support_vectors: JAX array of support vectors, shape (n_sv, n_features)
        dual_coef: JAX array of dual coefficients, shape (n_sv,)
        intercept: Scalar bias term.
        gamma: RBF kernel gamma parameter.

    Returns:
        Decision function value (scalar). Its sign gives the predicted class.
    """
    # Offset of the query point from every support vector: (n_sv, n_features)
    offsets = support_vectors - x
    # Squared Euclidean distance to each support vector: (n_sv,)
    sq_dist = jnp.sum(offsets ** 2, axis=1)
    # RBF kernel response of each support vector: (n_sv,)
    rbf = jnp.exp(-gamma * sq_dist)
    # Weighted kernel sum plus bias
    return jnp.sum(dual_coef * rbf) + intercept
def svm_predict_proba(x: jnp.ndarray, support_vectors: jnp.ndarray, dual_coef: jnp.ndarray, intercept: float, gamma: float):
    """Return the hard SVM class label (1.0 or 0.0) for ``x``.

    The value is 1.0 when the decision function from ``svm_predict`` is
    non-negative and 0.0 otherwise, so despite the name it is a binary
    indicator rather than a calibrated probability.
    """
    score = svm_predict(x, support_vectors, dual_coef, intercept, gamma)
    on_positive_side = score >= 0
    return jnp.where(on_positive_side, 1.0, 0.0)
# ----------------------------------------------------------------------------- # Neural Network Classifiers # ----------------------------------------------------------------------------- # Common training utilities
def train_with_restarts(
    train_fn: Callable,
    x: jnp.ndarray,
    y: jnp.ndarray,
    n_restarts: int = 2,
    init_params = None,
    **train_kwargs
) -> Tuple[Dict, Dict]:
    """
    Train a model several times on the full dataset and keep the best run.

    Args:
        train_fn: Training function called as
            ``train_fn(x_train=..., y_train=..., init_params=..., **train_kwargs)``
            and returning ``(params, metrics)``. ``metrics['train_loss']``
            must be convertible with ``float()`` (a number or numeric string).
        x: (N, d) features
        y: (N,) labels
        n_restarts: number of random restarts
        init_params: initial parameters for the first restart only; later
            restarts receive ``None``.
        **train_kwargs: passed to train_fn

    Returns:
        ``(best_params, best_metrics)`` of the restart with the lowest
        training loss; ties keep the earlier restart. If no restart yields a
        finite, comparable loss (e.g. every loss is NaN), the first restart's
        result is returned rather than ``None``. With ``n_restarts <= 0`` no
        training happens and ``(None, {})`` is returned.
    """
    import math  # local import: only needed for the NaN guard below

    best_train_loss = float("inf")
    best_params = None
    best_metrics = {}

    for i in range(n_restarts):
        # Use initial params for first restart, None for others
        restart_init_params = init_params if i == 0 else None
        if i == 0 and init_params is not None:
            log.debug(f"[Restart {i+1}/{n_restarts}] Using provided initial parameters")
        elif i > 0:
            log.debug(f"[Restart {i+1}/{n_restarts}] Using random initialization")

        # Use entire dataset for training (train_fn pulls fresh keys internally)
        params, metrics = train_fn(
            x_train=x,
            y_train=y,
            init_params=restart_init_params,
            **train_kwargs
        )
        train_loss = float(metrics['train_loss'])

        # `loss < best` is False whenever either side is NaN (and for
        # inf < inf), so a plain comparison could leave best_params as None
        # or pin a NaN incumbent forever. Always accept the first result and
        # let any later result replace a NaN incumbent.
        no_valid_incumbent = best_params is None or math.isnan(best_train_loss)
        if no_valid_incumbent or train_loss < best_train_loss:
            best_train_loss = train_loss
            best_params = params
            best_metrics = metrics
            log.debug(f"[Restart {i+1}/{n_restarts}] New best train_loss: {train_loss:.4e}")

    log.debug(f"[Training] Best model selected with train_loss = {best_train_loss:.4e}")
    return best_params, best_metrics
# Neural Network Classifier if FLAX_AVAILABLE: class MLPClassifier(nn.Module): hidden_dims: list = (32, 32) dropout_rate: float = 0.1 lr: float = 1e-3 weight_decay: float = 1e-4 n_epochs: int = 1000 batch_size: int = 128 early_stop_patience: int = 50 n_restarts: int = 2 val_frac: float = 0.2 @nn.compact def __call__(self, x, train: bool = False): for h in self.hidden_dims: x = nn.Dense(h)(x) x = nn.relu(x) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = nn.Dense(1)(x) return x else: MLPClassifier = None
def train_nn(
    model: MLPClassifier,
    x_train: jnp.ndarray,
    y_train: jnp.ndarray,
    init_params=None,
    **kwargs
):
    """Train the MLP classifier once on the entire dataset.

    Runs ``model.n_epochs`` epochs of mini-batch AdamW on the sigmoid binary
    cross-entropy loss. Each epoch reshuffles the data and performs
    ``max(1, N // batch_size)`` steps, so a trailing partial batch is skipped
    in that epoch.

    Args:
        model: ``MLPClassifier`` instance; also supplies ``lr``,
            ``weight_decay``, ``n_epochs`` and ``batch_size``.
        x_train: (N, d) training features.
        y_train: Training labels, one per row of ``x_train``.
        init_params: Optional parameters to start from; when ``None`` the
            model is freshly initialised with a new JAX key.
        **kwargs: Accepted for a uniform trainer signature; unused here.

    Returns:
        Tuple ``(params, metrics)``. ``metrics['train_loss']`` is the final
        full-dataset loss formatted as a ``'.2e'`` string and
        ``metrics['epochs']`` is the number of epochs actually run.
    """
    n_samples, d = x_train.shape
    batch_size = model.batch_size
    rng_opt = get_numpy_rng()

    # Handle initialization
    if init_params is not None:
        params = init_params
    else:
        params = model.init(get_new_jax_key(), jnp.ones((1, d)), train=True)

    optimizer = optax.adamw(model.lr, weight_decay=model.weight_decay)
    opt_state = optimizer.init(params)

    @jax.jit
    def loss_fn(params, batch_x, batch_y, rng):
        # Training objective: dropout active, driven by the supplied key.
        logits = model.apply(params, batch_x, train=True, rngs={"dropout": rng})
        return optax.sigmoid_binary_cross_entropy(logits.squeeze(-1), batch_y).mean()

    @jax.jit
    def eval_loss_fn(params, batch_x, batch_y):
        # Reporting objective: dropout disabled so the value is deterministic.
        logits = model.apply(params, batch_x, train=False)
        return optax.sigmoid_binary_cross_entropy(logits.squeeze(-1), batch_y).mean()

    @jax.jit
    def train_step(params, opt_state, batch_x, batch_y, rng):
        grads = jax.grad(loss_fn)(params, batch_x, batch_y, rng)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # Batches are sliced on the NumPy side, then converted per step.
    x_np, y_np = np.array(x_train), np.array(y_train)
    steps = max(1, n_samples // batch_size)

    # Tracked explicitly: the loop variable is unbound when n_epochs == 0,
    # which previously raised NameError while building the metrics.
    epochs_run = 0
    for epoch in range(model.n_epochs):
        perm_train = rng_opt.permutation(n_samples)
        for i in range(steps):
            idx = perm_train[i * batch_size:(i + 1) * batch_size]
            bx = jnp.array(x_np[idx])
            by = jnp.array(y_np[idx])
            params, opt_state = train_step(params, opt_state, bx, by, get_new_jax_key())
        epochs_run = epoch + 1

    # Compute final training loss without dropout. This value is what the
    # restart wrapper compares, so it must not depend on a random dropout mask.
    final_train_loss = eval_loss_fn(params, x_train, y_train)

    metrics = {
        'train_loss': f"{float(final_train_loss):.2e}",
        'epochs': epochs_run,
    }
    return params, metrics
def train_nn_multiple_restarts(model: MLPClassifier, x: jnp.ndarray, y: jnp.ndarray, **kwargs):
    """Run ``train_nn`` for ``model.n_restarts`` restarts and keep the best.

    Extra keyword arguments (e.g. ``init_params``) are forwarded to
    ``train_with_restarts``.
    """
    single_run = partial(train_nn, model)
    return train_with_restarts(
        single_run,
        x,
        y,
        n_restarts=model.n_restarts,
        **kwargs
    )
# Ellipsoid Classifier with center at best fit point if FLAX_AVAILABLE: class EllipsoidClassifier(nn.Module): d: int mu: jnp.ndarray init_scale: float = 0.1 lr: float = 1e-2 weight_decay: float = 1e-4 n_epochs: int = 1000 batch_size: int = 64 patience: int = 25 n_restarts: int = 2 val_frac: float = 0.1 def setup(self): tril = self.d * (self.d + 1) // 2 self.flat_L = self.param("flat_L", nn.initializers.normal(self.init_scale), (tril,)) self.alpha = self.param("alpha", nn.initializers.ones, ()) self.beta = self.param("beta", nn.initializers.zeros, ()) def _unpack_L(self): L_matrix = jnp.zeros((self.d, self.d)) tril_indices = jnp.tril_indices(self.d) rows, cols = tril_indices diagonal_mask = rows == cols flat_L_processed = jnp.where(diagonal_mask, nn.softplus(self.flat_L) + 1e-4, self.flat_L) return L_matrix.at[tril_indices].set(flat_L_processed) @nn.compact def __call__(self, x, train: bool = False): L = self._unpack_L() diff = x - self.mu md2 = jnp.einsum("...i,ij,...j->...", diff, L @ L.T, diff) logit = -self.alpha * md2 + self.beta return logit else: EllipsoidClassifier = None
def train_ellipsoid(
    model: EllipsoidClassifier,
    x_train: jnp.ndarray,
    y_train: jnp.ndarray,
    init_params=None,
    **kwargs
):
    """Train the ellipsoid classifier once on the entire dataset.

    Runs ``model.n_epochs`` epochs of mini-batch AdamW on the sigmoid binary
    cross-entropy loss. Each epoch reshuffles the data and performs
    ``max(1, N // batch_size)`` steps, so a trailing partial batch is skipped
    in that epoch.

    Args:
        model: ``EllipsoidClassifier`` instance; also supplies ``lr``,
            ``weight_decay``, ``n_epochs`` and ``batch_size``.
        x_train: Training features; the first axis indexes samples.
        y_train: Training labels, one per sample.
        init_params: Optional parameters to start from; when ``None`` the
            model is freshly initialised with a new JAX key.
        **kwargs: Accepted for a uniform trainer signature; unused here.

    Returns:
        Tuple ``(params, metrics)``. ``metrics['train_loss']`` is the final
        full-dataset loss formatted as a ``'.2e'`` string and
        ``metrics['epochs']`` is the number of epochs actually run.
    """
    rng = get_numpy_rng()
    n_samples = x_train.shape[0]
    batch_size = model.batch_size

    # Handle initialization
    if init_params is not None:
        params = init_params
    else:
        params = model.init(get_new_jax_key(), x_train)

    optimizer = optax.adamw(model.lr, weight_decay=model.weight_decay)
    opt_state = optimizer.init(params)

    @jax.jit
    def loss_fn(params, batch_x, batch_y):
        logits = model.apply(params, batch_x, train=False)
        return optax.sigmoid_binary_cross_entropy(logits, batch_y).mean()

    @jax.jit
    def train_step(params, opt_state, bx, by):
        grads = jax.grad(loss_fn)(params, bx, by)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # Batches are sliced on the NumPy side, then converted per step.
    x_np, y_np = np.array(x_train), np.array(y_train)
    steps = max(1, n_samples // batch_size)

    # Tracked explicitly: the loop variable is unbound when n_epochs == 0,
    # which previously raised NameError while building the metrics.
    epochs_run = 0
    for epoch in range(model.n_epochs):
        perm_train = rng.permutation(n_samples)
        for i in range(steps):
            idx = perm_train[i * batch_size:(i + 1) * batch_size]
            bx = jnp.array(x_np[idx])
            by = jnp.array(y_np[idx])
            params, opt_state = train_step(params, opt_state, bx, by)
        epochs_run = epoch + 1

    # Compute final training loss
    final_train_loss = loss_fn(params, x_train, y_train)

    metrics = {
        'train_loss': f"{float(final_train_loss):.2e}",
        'epochs': epochs_run,
    }
    return params, metrics
def train_ellipsoid_multiple_restarts(model: EllipsoidClassifier, x: jnp.ndarray, y: jnp.ndarray, **kwargs):
    """Run ``train_ellipsoid`` for ``model.n_restarts`` restarts and keep the best.

    Extra keyword arguments (e.g. ``init_params``) are forwarded to
    ``train_with_restarts``.
    """
    single_run = partial(train_ellipsoid, model)
    return train_with_restarts(
        single_run,
        x,
        y,
        n_restarts=model.n_restarts,
        **kwargs
    )