import numpy as np
import jax
import jax.numpy as jnp
from typing import Callable, Dict, List, Any, Union, Optional, Tuple
from BOBE.utils.seed import set_global_seed, get_numpy_rng, get_new_jax_key
from BOBE.utils.log import get_logger
from BOBE.gp import GP
from BOBE.clf_gp import GPwithClassifier
from BOBE.likelihood import Likelihood, CobayaLikelihood
log = get_logger('pool')
# mpi4py is an optional dependency: when it cannot be imported the pool
# falls back to single-process (serial) behaviour.
try:
    from mpi4py import MPI
    IS_MPI_AVAILABLE = True
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit raised during import still propagate.
    # Exception (not just ImportError) keeps the best-effort fallback for
    # any ordinary failure while loading the MPI bindings.
    MPI = None
    IS_MPI_AVAILABLE = False
[docs]
class MPI_Pool:
"""
Enhanced MPI Pool with support for managing worker state and multiple task types.
This pool implements a master-worker pattern where workers enter a waiting loop
and the master dispatches tasks dynamically. Workers automatically participate
after initialization and don't need explicit management in user code.

Every message from the master to a worker is a ``(task_type, data)`` tuple,
where ``task_type`` is one of the integer ``TASK_*`` codes defined below.
"""
# Integer task codes carried in the first slot of each master -> worker message.
# Worker evaluates the likelihood at a single point and returns the value.
TASK_OBJECTIVE_EVAL = 0
# Worker rebuilds a GP from a state dict, fits it and returns the fit results.
TASK_GP_FIT = 1
# NOTE(review): value 2 is unused, and worker_wait has no dedicated branch for
# this code in the visible source — confirm whether it is still needed.
TASK_ACQUISITION_OPT = 3
# Worker draws one valid initial point via the likelihood's _get_single_valid_point.
TASK_COBAYA_INIT = 4
# Worker calls jax.clear_caches(); no reply is sent for this task.
TASK_CLEAR_JAX_CACHES = 5
# NOTE(review): not referenced anywhere else in the visible source.
TASK_INIT = 99
# Worker leaves its wait loop; no reply is sent for this task.
TASK_EXIT = 100
[docs]
def __init__(self, dynamic_dispatch: bool = False):
    """
    Set up the pool's view of the MPI world.

    When mpi4py could be imported, the communicator, rank and size are taken
    from ``MPI.COMM_WORLD``; otherwise the pool describes a single serial
    process (rank 0 of 1).

    Parameters
    ----------
    dynamic_dispatch : bool, optional
        Stored on the instance as ``dynamic_dispatch``. The mapping methods
        read it to choose between dynamic and static (round-robin) task
        distribution. Default is False (static).
    """
    if not IS_MPI_AVAILABLE:
        # Serial fallback: a lone "master" with no communicator.
        communicator, rank, size = None, 0, 1
    else:
        communicator = MPI.COMM_WORLD
        rank = communicator.Get_rank()
        size = communicator.Get_size()

    self.comm = communicator
    self.rank = rank
    self.size = size
    # The role flags are all derived from rank/size, so they agree in both modes.
    self.is_mpi = size > 1
    self.is_main_process = rank == 0
    self.is_worker = rank > 0

    # Flag tracking whether workers are inside their waiting loop.
    self._workers_active = False
    # Static vs dynamic task dispatch.
    self.dynamic_dispatch = dynamic_dispatch
[docs]
def worker_wait(self, likelihood: Likelihood, gp: Union[GP, GPwithClassifier] = None, seed: Optional[int] = None):
"""
Main loop for worker processes. Workers wait for tasks from master and execute them.
This method should be called by worker processes after initialization is complete.
It enters an infinite loop waiting for tasks until TASK_EXIT is received.
Parameters
----------
likelihood : Likelihood
The likelihood object for evaluating objective function.
gp : GP or GPwithClassifier, optional
The GP object, can be updated via state_dict broadcasts.
seed : int, optional
Random seed for the worker. If provided, will be offset by rank.
Notes
-----
This method only executes for worker processes (rank > 0).
The master process (rank 0) immediately returns.
"""
if not self.is_worker:
return
log.info(f"Worker {self.rank} entering wait loop")
if seed is not None:
seed = seed + self.rank
set_global_seed(seed)
rng = get_numpy_rng()
# Store likelihood and GP for task execution
self._likelihood = likelihood
self._gp = gp
# Wire format: every task message is (task_type, data) where
# data = (payload, task_index) for TASK_OBJECTIVE_EVAL / TASK_COBAYA_INIT / TASK_GP_FIT
# data = None for TASK_EXIT / TASK_CLEAR_JAX_CACHES
while True:
# Wait for task from master
task_data = self.comm.recv(source=0, tag=MPI.ANY_TAG)
task_type, data = task_data
try:
if task_type == self.TASK_OBJECTIVE_EVAL:
# Evaluate likelihood at a point
point, task_index = data
result = self._likelihood(point)
self.comm.send((result, task_index), dest=0)
elif task_type == self.TASK_GP_FIT:
# Fit GP with given starting points
payload, task_index = data
state_dict = payload['state_dict']
fit_params = payload['fit_params']
use_clf = payload.get('use_clf', False)
# Reconstruct GP from state dict
if use_clf:
worker_gp = GPwithClassifier.from_state_dict(state_dict)
else:
worker_gp = GP.from_state_dict(state_dict)
# Fit GP and return results
fit_results = worker_gp.fit(**fit_params)
self.comm.send((fit_results, task_index), dest=0)
elif task_type == self.TASK_COBAYA_INIT:
# Get initial point from Cobaya reference prior
_, task_index = data
pt, logpost = self._likelihood._get_single_valid_point(rng)
self.comm.send(((pt, logpost), task_index), dest=0)
elif task_type == self.TASK_CLEAR_JAX_CACHES:
# Clear JAX caches on worker
jax.clear_caches()
elif task_type == self.TASK_EXIT:
log.info(f"Worker {self.rank} exiting")
break
except Exception as e:
import traceback
log.error(f"Worker {self.rank} error: {e}")
log.error(traceback.format_exc())
# Send error back to master with task index (skip for tasks with no index)
if data is None:
pass # TASK_EXIT / TASK_CLEAR_JAX_CACHES — nothing to report
else:
_, task_index = data
self.comm.send(("error", str(e), task_index), dest=0)
return
def _dynamic_distribute(self, tasks: List[Any], task_type: int) -> List[Any]:
"""
MASTER-ONLY METHOD: Distributes tasks to workers using dynamic scheduling.
This is a generic utility that sends tasks one by one to available workers
and collects the results in order. Workers must be in worker_wait() loop.
Parameters
----------
tasks : list
List of task data to distribute to workers.
task_type : int
Type of task (TASK_OBJECTIVE_EVAL, TASK_GP_FIT, etc.).
Returns
-------
list
Results from all tasks in the same order as input tasks.
"""
if not self.is_main_process or not self.is_mpi:
raise RuntimeError("_dynamic_distribute is designed for the master process in MPI mode.")
n_tasks = len(tasks)
if n_tasks == 0:
return []
results = [None] * n_tasks
task_index = 0
tasks_in_progress = 0
# Initial distribution to all available workers
for worker_rank in range(1, self.size):
if task_index < n_tasks:
payload = (tasks[task_index], task_index)
self.comm.send((task_type, payload), dest=worker_rank)
task_index += 1
tasks_in_progress += 1
# Receive results and distribute remaining tasks
while tasks_in_progress > 0:
status = MPI.Status()
# The worker returns (result, original_index) or (error, msg, original_index)
response = self.comm.recv(source=MPI.ANY_SOURCE, status=status)
worker_rank = status.Get_source()
if len(response) == 3 and response[0] == "error":
_, msg, original_index = response
log.error(f"Worker {worker_rank} failed on task {original_index}: {msg}")
# Propagate the error to be handled by the caller
raise RuntimeError(f"Worker {worker_rank} failed: {msg}")
result, original_index = response
results[original_index] = result
tasks_in_progress -= 1
# If there are more tasks, send one to the newly freed worker
if task_index < n_tasks:
payload = (tasks[task_index], task_index)
self.comm.send((task_type, payload), dest=worker_rank)
task_index += 1
tasks_in_progress += 1
return results
def _static_distribute(self, tasks: List[Any], task_type: int, local_fn: Callable = None) -> List[Any]:
"""
MASTER-ONLY METHOD: Distributes tasks statically (round-robin) across all ranks.
Master participates in evaluation. Task i is always assigned to rank i % size,
making the assignment deterministic given (seed, nprocs). Uses the same
individual task messages as _dynamic_distribute — no extra task type needed.
Parameters
----------
tasks : list
List of task data to distribute.
task_type : int
Type of task (TASK_OBJECTIVE_EVAL or TASK_COBAYA_INIT).
local_fn : callable
Function for the master to evaluate its own tasks. Must accept a single
task element and return its result.
Returns
-------
list
Results from all tasks in the same order as input tasks.
"""
if not self.is_main_process or not self.is_mpi:
raise RuntimeError("_static_distribute is designed for the master process in MPI mode.")
n_tasks = len(tasks)
if n_tasks == 0:
return []
results = [None] * n_tasks
master_tasks = [] # (task, global_index) pairs for master to run locally
# Send each task to its pre-assigned rank (round-robin)
for i, task in enumerate(tasks):
assigned_rank = i % self.size
if assigned_rank == 0:
master_tasks.append((task, i))
else:
self.comm.send((task_type, (task, i)), dest=assigned_rank)
# Master evaluates its own tasks locally
for task, idx in master_tasks:
results[idx] = local_fn(task)
# Collect one response per non-master task
n_worker_tasks = n_tasks - len(master_tasks)
for _ in range(n_worker_tasks):
response = self.comm.recv(source=MPI.ANY_SOURCE)
if len(response) == 3 and response[0] == "error":
_, msg, original_index = response
log.error(f"Worker failed on task {original_index}: {msg}")
raise RuntimeError(f"Worker failed on task {original_index}: {msg}")
result, original_index = response
results[original_index] = result
return results
# REFACTORED: Now uses the central utility
[docs]
def run_map_objective(self, function: Callable, tasks: List[Any]) -> np.ndarray:
"""
Maps a function over a list of tasks in parallel.
In MPI mode, distributes tasks to workers dynamically. Workers must be
in worker_wait() loop. In serial mode, evaluates locally.
Parameters
----------
function : callable
The objective/likelihood function to evaluate. Only used in serial mode.
tasks : list or array-like
List of input points to evaluate, shape (n_tasks, ndim).
Returns
-------
np.ndarray
Array of results, shape (n_tasks,) or (n_tasks, 1).
"""
if not self.is_main_process:
return None
if not self.is_mpi:
# Serial execution if not in MPI mode
results = [function(task) for task in tasks]
elif self.dynamic_dispatch:
results = self._dynamic_distribute(tasks, self.TASK_OBJECTIVE_EVAL)
else:
results = self._static_distribute(tasks, self.TASK_OBJECTIVE_EVAL, local_fn=function)
return np.array(results)
[docs]
def gp_fit(self, gp: GP, maxiters=1000, n_restarts=8, rng=None, use_pool=True):
"""
Orchestrates a parallel GP hyperparameter fit across MPI processes.
Distributes multiple random restarts across workers for hyperparameter
optimization and selects the best result.
Parameters
----------
gp : GP or GPwithClassifier
Gaussian Process model to fit.
maxiters : int, optional
Maximum iterations for each optimization. Default is 1000.
n_restarts : int, optional
Number of random restarts for optimization. Default is 8.
In MPI mode, adjusted to at least one restart per process.
rng : np.random.Generator, optional
Random number generator for initial points. If None, creates new one.
use_pool : bool, optional
Whether to use MPI pool for parallelization. Default is True.
Returns
-------
dict or None
Best fit result for master process, None for workers.
"""
if self.is_worker:
return None
# Adjust n_restarts to be at least equal to the number of processes
if self.is_mpi and use_pool:
n_restarts = max(self.size, n_restarts)
n_restarts = min(n_restarts, 2 * self.size)
rng = np.random.default_rng() if rng is None else rng
n_params = gp.hyperparam_bounds.shape[1] # hp bounds are (2, n_params) shaped
# Prepare initial parameters for all restarts
init_params = jnp.log(gp.get_hyperparams())
if n_restarts > 1:
x0_random = rng.uniform(
gp.hyperparam_bounds[0],
gp.hyperparam_bounds[1],
size=(n_restarts - 1, n_params)
)
x0 = np.vstack([init_params, x0_random])
else:
x0 = np.atleast_2d(init_params)
# If not running in MPI or use_pool=False, call the GP's local fit method
if not self.is_mpi or not use_pool:
log.info(f"Running serial GP fit with {n_restarts} restarts.")
results = gp.fit(x0=x0, maxiter=maxiters)
gp.update_hyperparams(results['params'])
return results
# MPI Parallel Block - distribute restarts across workers
log.info(f"Running parallel GP fit with {n_restarts} restarts across {self.size} MPI processes.")
# Split initial points across processes
x0_chunks = np.array_split(x0, self.size)
state_dict = gp.state_dict()
# Send tasks to workers
for i in range(1, self.size):
worker_x0 = x0_chunks[i]
fit_params = {'x0': worker_x0, 'maxiter': maxiters}
payload = {
'state_dict': state_dict,
'fit_params': fit_params,
'use_clf': isinstance(gp, GPwithClassifier)
}
self.comm.send((self.TASK_GP_FIT, (payload, i)), dest=i)
# Master does its share of work
master_x0 = x0_chunks[0]
master_result = gp.fit(x0=master_x0, maxiter=maxiters)
# Collect results from workers
all_results = [master_result]
for i in range(1, self.size):
worker_result, _ = self.comm.recv(source=i)
all_results.append(worker_result)
# Select best result and update GP
best_result = max(all_results, key=lambda r: r['mll'])
best_params = best_result['params']
gp.update_hyperparams(best_params)
return best_result
[docs]
def get_cobaya_initial_points(self, likelihood: CobayaLikelihood, n_points: int, rng=None) -> List[Tuple]:
"""
Gets initial points from the Cobaya reference prior in parallel.
Distributes the generation of Cobaya initial points across workers.
Workers must be in worker_wait() loop.
Parameters
----------
likelihood : CobayaLikelihood
Cobaya likelihood object with _get_single_valid_point method.
n_points : int
Number of initial points to generate.
rng : np.random.Generator, optional
Random number generator. Only used in serial mode.
Returns
-------
list of tuple
List of (point, logpost) tuples for master process, None for workers.
"""
if not self.is_main_process:
return None
if not self.is_mpi:
# Serial execution
rng = np.random.default_rng() if rng is None else rng
return [likelihood._get_single_valid_point(rng) for _ in range(n_points)]
# The payload for this task is trivial; we just need to send n_points signals
tasks = [None] * n_points
if self.dynamic_dispatch:
results_tuples = self._dynamic_distribute(tasks, self.TASK_COBAYA_INIT)
else:
master_rng = get_numpy_rng()
results_tuples = self._static_distribute(
tasks, self.TASK_COBAYA_INIT,
local_fn=lambda _: likelihood._get_single_valid_point(master_rng)
)
return results_tuples
[docs]
def clear_jax_caches(self):
"""Clear JAX caches on all processes."""
jax.clear_caches()
if self.is_mpi and self.is_main_process:
for rank in range(1, self.size):
self.comm.send((self.TASK_CLEAR_JAX_CACHES, None), dest=rank)
[docs]
def close(self):
"""
Shut down the pool by telling all workers to exit.
Sends TASK_EXIT signal to all worker processes, allowing them to
exit from the worker_wait() loop gracefully.
"""
if self.is_worker:
return
if self.is_mpi and self.size > 1:
log.info("Sending exit signal to all workers")
for i in range(1, self.size):
self.comm.send((self.TASK_EXIT, None), dest=i)