# Source code for BOBE.kernels

"""
Kernel implementations for Gaussian Process models.

All kernels inherit from the base Kernel class and implement the covariance() method.
JAX JIT compilation is handled at higher levels (acquisition functions, optimization).
"""

from abc import ABC, abstractmethod
from math import sqrt
import jax
import jax.numpy as jnp

# Enable 64-bit floats globally at import time. Kernel matrices are
# Cholesky-factorised downstream and float32 is often too imprecise for that.
# NOTE: this mutates process-wide JAX configuration for every importer.
jax.config.update("jax_enable_x64", True)

# Precomputed square roots for Matérn-family kernel formulas (sqrt5 is used by
# the Matérn-5/2 kernel; sqrt2/sqrt3 are presumably used elsewhere — verify).
sqrt2, sqrt3, sqrt5 = (sqrt(value) for value in (2., 3., 5.))


class Kernel(ABC):
    """
    Common base for every covariance function in BOBE.

    Concrete kernels only have to implement ``covariance``. This class
    provides pairwise squared distances, a constant diagonal, in-place
    hyperparameter updates and call syntax (``kernel(xa, xb)``).

    Attributes
    ----------
    lengthscales : jnp.ndarray
        Per-dimension lengthscales, shape (D,).
    kernel_variance : float
        Signal variance (amplitude) of the kernel.
    noise : float
        Observation-noise level that may be added to the diagonal.
    """

    def __init__(self, lengthscales, kernel_variance, noise=1e-8):
        """
        Store the kernel hyperparameters.

        Parameters
        ----------
        lengthscales : array_like
            One lengthscale per input dimension; converted with ``jnp.array``.
        kernel_variance : float
            Kernel variance / amplitude.
        noise : float, optional
            Diagonal noise level. Defaults to 1e-8.
        """
        self.lengthscales = jnp.array(lengthscales)
        self.kernel_variance = kernel_variance
        self.noise = noise

    def sq_dist(self, xa, xb):
        """
        Pairwise squared Euclidean distances between two point sets.

        Parameters
        ----------
        xa : jnp.ndarray
            Points of shape (n1, D).
        xb : jnp.ndarray
            Points of shape (n2, D).

        Returns
        -------
        jnp.ndarray
            Matrix of shape (n1, n2) whose (i, j) entry is ||xa[i] - xb[j]||².
        """
        # (n1, 1, D) minus (1, n2, D) broadcasts to all pairwise differences.
        pairwise_diff = xa[:, None, :] - xb[None, :, :]
        return jnp.square(pairwise_diff).sum(axis=-1)

    @abstractmethod
    def covariance(self, xa, xb, include_noise=True):
        """
        Covariance matrix between two point sets (subclasses must implement).

        Parameters
        ----------
        xa : jnp.ndarray
            Points of shape (n1, D).
        xb : jnp.ndarray
            Points of shape (n2, D).
        include_noise : bool, optional
            Whether to add noise on the diagonal; intended only for the
            self-covariance case where xa is xb. Default is True.

        Returns
        -------
        jnp.ndarray
            Covariance matrix of shape (n1, n2).
        """
        ...

    def diagonal(self, x, include_noise=True):
        """
        Diagonal of K(x, x) without building the full matrix.

        Stationary kernels have a constant diagonal equal to
        ``kernel_variance`` (plus ``noise`` when requested). Kernels whose
        diagonal varies with x should override this.

        Parameters
        ----------
        x : jnp.ndarray
            Points of shape (n, D).
        include_noise : bool, optional
            Whether to add ``noise``. Default is True.

        Returns
        -------
        jnp.ndarray
            Diagonal values, shape (n,).
        """
        constant_diag = self.kernel_variance * jnp.ones(x.shape[0])
        return constant_diag + self.noise if include_noise else constant_diag

    def update_hyperparams(self, lengthscales=None, kernel_variance=None, noise=None):
        """
        Overwrite any subset of the hyperparameters; ``None`` leaves a value as is.

        Parameters
        ----------
        lengthscales : array_like, optional
            New lengthscales (converted with ``jnp.array``).
        kernel_variance : float, optional
            New kernel variance.
        noise : float, optional
            New noise level.
        """
        if lengthscales is not None:
            self.lengthscales = jnp.array(lengthscales)
        if kernel_variance is not None:
            self.kernel_variance = kernel_variance
        if noise is not None:
            self.noise = noise

    def __call__(self, xa, xb, include_noise=True):
        """Shorthand for ``self.covariance(xa, xb, include_noise=...)``."""
        return self.covariance(xa, xb, include_noise=include_noise)
class RBFKernel(Kernel):
    """
    Radial Basis Function (RBF) / Squared Exponential kernel.

    k(x, x') = σ² * exp(-0.5 * ||x - x'||²/ℓ²)

    where σ² is kernel_variance and ℓ is lengthscale. There is one
    lengthscale per input dimension; inputs are divided by it before
    distances are taken.
    """

    def covariance(self, xa, xb, include_noise=True):
        """
        Compute the RBF covariance matrix.

        Parameters
        ----------
        xa : jnp.ndarray
            First set of input points, shape (n1, d).
        xb : jnp.ndarray
            Second set of input points, shape (n2, d).
        include_noise : bool, optional
            Whether to add ``self.noise`` to the diagonal. Noise is added only
            when the result is square (n1 == n2). For a rectangular
            cross-covariance it is skipped, because adding ``jnp.eye(n1)``
            would either fail to broadcast or silently broadcast into the
            wrong entries/shape. Default is True.

        Returns
        -------
        jnp.ndarray
            Kernel matrix of shape (n1, n2).

        Notes
        -----
        The square-shape test cannot distinguish K(x, x) from a square
        cross-covariance between two *different* point sets of equal size.
        Pass ``include_noise=False`` for the latter.
        """
        # Work in lengthscale-normalised coordinates: ||(x - x') / ℓ||².
        xa_scaled = xa / self.lengthscales
        xb_scaled = xb / self.lengthscales
        dist_sq = self.sq_dist(xa_scaled, xb_scaled)

        K = self.kernel_variance * jnp.exp(-0.5 * dist_sq)

        # Shapes are static under jax.jit, so this Python-level branch is
        # trace-safe. Restricting it to square K keeps the output (n1, n2).
        if include_noise and K.shape[0] == K.shape[1]:
            K = K + self.noise * jnp.eye(K.shape[0])
        return K
class MaternKernel(Kernel):
    """
    Matérn-5/2 kernel.

    k(x, x') = σ² * (1 + √5*d + 5*d²/3) * exp(-√5*d)

    where d = ||x - x'||/ℓ, σ² is kernel_variance, and ℓ is lengthscale.
    There is one lengthscale per input dimension, applied before the
    distance is taken.
    """

    def covariance(self, xa, xb, include_noise=True):
        """
        Compute the Matérn-5/2 covariance matrix.

        Parameters
        ----------
        xa : jnp.ndarray
            First set of input points, shape (n1, d).
        xb : jnp.ndarray
            Second set of input points, shape (n2, d).
        include_noise : bool, optional
            Whether to add ``self.noise`` to the diagonal. Noise is added only
            when the result is square (n1 == n2). For a rectangular
            cross-covariance it is skipped, because adding ``jnp.eye(n1)``
            would either fail to broadcast or silently broadcast into the
            wrong entries/shape. Default is True.

        Returns
        -------
        jnp.ndarray
            Kernel matrix of shape (n1, n2).

        Notes
        -----
        The square-shape test cannot distinguish K(x, x) from a square
        cross-covariance between two *different* point sets of equal size.
        Pass ``include_noise=False`` for the latter.
        """
        # Work in lengthscale-normalised coordinates.
        xa_scaled = xa / self.lengthscales
        xb_scaled = xb / self.lengthscales
        dist_sq = self.sq_dist(xa_scaled, xb_scaled)

        # sqrt has an unbounded derivative at 0, so coincident points would
        # produce inf/NaN gradients. Clamp tiny squared distances before the
        # sqrt. The resulting d ≈ 1e-15 leaves k(x, x) equal to σ² to
        # working precision.
        d = jnp.sqrt(jnp.where(dist_sq < 1e-30, 1e-30, dist_sq))

        # Matérn-5/2: (1 + √5 d + 5 d²/3) * exp(-√5 d), polynomial in Horner form.
        exp_term = jnp.exp(-sqrt5 * d)
        poly_term = 1. + d * (sqrt5 + d * 5. / 3.)
        K = self.kernel_variance * poly_term * exp_term

        # Shapes are static under jax.jit, so this Python-level branch is
        # trace-safe. Restricting it to square K keeps the output (n1, n2).
        if include_noise and K.shape[0] == K.shape[1]:
            K = K + self.noise * jnp.eye(K.shape[0])
        return K