User Guide#
This guide focuses on the pieces that are specific to bde. If you are new to
scikit-learn’s estimator API, refer to the official developer guide for the foundational
concepts. The sections below assume that background and concentrate on how
BdeRegressor and BdeClassifier behave, how they integrate with JAX, and how
you should prepare data to get reliable results.
For installation, environment setup, and JAX device configuration, start with
Quick start.
Method#
bde implements Bayesian Deep Ensembles using Microcanonical Langevin Ensembles (MILE),
a hybrid approach for Bayesian neural network inference that combines deterministic
ensemble pre-training with efficient MCMC sampling.
Concretely, ensemble members are first trained independently using standard optimization, providing diverse and well-initialized starting points. These models are then refined using Microcanonical Langevin Monte Carlo (MCLMC) to generate high-quality posterior samples. This design yields strong predictive performance and reliable uncertainty estimates while remaining computationally efficient.
The method is particularly well-suited for the complex, multi-modal posteriors encountered in neural networks and can be run in an embarrassingly parallel fashion across multiple devices.
Note that this package currently supports only fully connected feedforward networks and is targeted at tabular data tasks. The method can in principle be applied to other architectures and data modalities, but these are not yet within the scope of this implementation. Furthermore, the package operates in the full-batch setting only; stochastic mini-batching is not supported at this time.
For references and theoretical as well as algorithmic details, see Microcanonical Langevin Ensembles: Advancing the Sampling of Bayesian Neural Networks (ICLR 2025).
Estimator overview#
bde exposes two scikit-learn compatible estimators:
bde.BdeRegressor for continuous targets.
bde.BdeClassifier for categorical targets.
Both inherit sklearn.base.BaseEstimator and the relevant mixins, so they
support the familiar fit/predict/score methods, accept keyword
hyperparameters in __init__, and can be dropped into a
sklearn.pipeline.Pipeline. Under the hood they train a fully connected
ensemble in JAX and then run an MCMC sampler to draw posterior weight samples. At
prediction time the estimator combines those samples to provide means, standard
deviations, credible intervals, probability vectors, or the raw ensemble outputs.
You can customise the architecture and training stack: choose any activation
function, swap in your own optimiser, or rely on the defaults (optax.adamw).
Losses also default sensibly: regression uses bde.loss.GaussianNLL and
classification uses bde.loss.CategoricalCrossEntropy.
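As a quick orientation, the sketch below drops BdeRegressor into a scikit-learn pipeline. It is illustrative only: the constructor arguments follow the hyperparameter descriptions later in this guide, and the concrete values are placeholders rather than recommended settings.

# Illustrative sketch; see "Key hyperparameters" below for what the arguments mean.
from sklearn.datasets import make_regression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from bde import BdeRegressor

X, y = make_regression(n_samples=200, n_features=8, noise=0.1, random_state=0)
y = (y - y.mean()) / y.std()  # standardise the target (see "Data preparation")

model = Pipeline([
    ("scale", StandardScaler()),  # zero-mean, unit-variance features
    ("bde", BdeRegressor(n_members=4, hidden_layers=[16, 16])),
])
model.fit(X, y)
pred = model.predict(X)  # posterior-mean predictions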
Data preparation#
Bayesian deep ensembles are sensitive to feature and target scale because the
networks are initialised with zero-mean weights and the prior assumes unit-scale
activations. Large raw targets (for instance the default output of
sklearn.datasets.make_regression()) can lead to very poor fits if left
unscaled. Always apply basic preprocessing before calling fit:
Standardise features with sklearn.preprocessing.StandardScaler (or an equivalent transformer) so each column has roughly zero mean and unit variance.
For regression, standardise the target as well and keep the scaler handy if you need to transform predictions back to the original scale, as sketched below.
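A minimal preprocessing sketch for a regression task, using only standard scikit-learn transformers; keep the target scaler around so predictions can be mapped back to the original scale.

from sklearn.datasets import make_regression
from sklearn.preprocessing import StandardScaler

X, y = make_regression(n_samples=500, n_features=10, noise=5.0, random_state=0)

x_scaler = StandardScaler().fit(X)
y_scaler = StandardScaler().fit(y.reshape(-1, 1))

X_scaled = x_scaler.transform(X)
y_scaled = y_scaler.transform(y.reshape(-1, 1)).ravel()

# ... fit the estimator on (X_scaled, y_scaled), then invert predictions:
# y_pred = y_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()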
Gaussian likelihood (regression)#
Regression heads emit a mean and an unconstrained scale. The scale is mapped to a
positive standard deviation with softplus (plus a small epsilon) in all stages:
the training loss bde.loss.GaussianNLL, the
posterior log-likelihood in bde.sampler.probabilistic.ProbabilisticModel.log_likelihood(),
and the prediction helpers in bde.bde_evaluator.BdePredictor._regression_mu_sigma().
Note
If you request raw=True from the regressor you receive the unconstrained scale
head and should apply the same softplus transform before treating it as a
standard deviation.
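For illustration, the transform looks roughly like the sketch below; the exact epsilon used internally is not spelled out here, so the value in the sketch is an assumption.

import jax.numpy as jnp
from jax.nn import softplus

def scale_to_std(raw_scale, eps=1e-6):
    # softplus keeps the value strictly positive; eps (assumed value) avoids exactly-zero stds
    return softplus(jnp.asarray(raw_scale)) + eps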
Understanding the outputs#
The estimators expose several prediction modes:
predict(X)
Returns the mean prediction (regression) or hard labels (classification).
predict(X, mean_and_std=True)
Regression only; returns a tuple (mean, std) where std combines aleatoric and epistemic components.
predict(X, credible_intervals=[0.05, 0.95])
Regression only; returns (mean, quantiles) where each quantile is computed from Monte Carlo samples drawn from every posterior component (i.e. the full mixture across ensemble members and MCMC draws). This reflects the predictive distribution of the entire ensemble rather than just parameter quantiles. For small posterior sample counts (n_samples < 10) a small random draw is used; for very large counts (n_samples > 10_000) a single sample is taken to keep the computation cheap.
predict(X, raw=True)
Returns the raw tensor with leading axes (ensemble_members, samples, n, output_dims). Useful for custom diagnostics.
predict_proba(X)
Classification only; returns class probability vectors.
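The calls below illustrate these modes; they assume reg is a fitted BdeRegressor, clf a fitted BdeClassifier, and X_test an appropriately scaled feature matrix.

point = reg.predict(X_test)                              # mean prediction
mean, std = reg.predict(X_test, mean_and_std=True)       # total predictive std
mean, quantiles = reg.predict(X_test, credible_intervals=[0.05, 0.95])  # 90% band
raw = reg.predict(X_test, raw=True)                      # (ensemble_members, samples, n, output_dims)

proba = clf.predict_proba(X_test)                        # class probability vectors
labels = clf.predict(X_test)                             # hard labels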
How to read uncertainties#
Mean + std (mean_and_std=True): std is the total predictive standard deviation. It sums aleatoric variance (averaged scale head) and epistemic variance (spread of ensemble means), so high values mean either noisy data or disagreement across members.
Credible intervals (credible_intervals=[...]): Quantiles are taken over samples from the full mixture of ensemble members and posterior draws. This captures both aleatoric and epistemic uncertainty. For example, requesting [0.05, 0.95] returns lower/upper curves you can treat as a 90% credible band.
Raw outputs (raw=True): Shape (E, T, N, D) for regression, where E = ensemble_members, T = n_samples, N = n_data, and D = 2 (mean, scale). You can manually compute aleatoric vs epistemic components, plot per-member predictions, or customise intervals if needed; a sketch follows below.
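As an example of working with the raw output, the sketch below splits the total predictive variance into aleatoric and epistemic parts. It assumes reg and X_test from above; the softplus-plus-epsilon mapping mirrors the scale transform described earlier, with the epsilon value assumed.

import jax.numpy as jnp
from jax.nn import softplus

raw = reg.predict(X_test, raw=True)                  # (E, T, N, 2): mean and unconstrained scale
mu = raw[..., 0]                                     # per-member, per-draw means
sigma = softplus(raw[..., 1]) + 1e-6                 # per-member, per-draw standard deviations

aleatoric_var = jnp.mean(sigma ** 2, axis=(0, 1))    # averaged noise variance per data point
epistemic_var = jnp.var(mu, axis=(0, 1))             # spread of the means per data point
total_std = jnp.sqrt(aleatoric_var + epistemic_var)  # comparable to mean_and_std=True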
Key hyperparameters#
Model architecture
n_members
Number of deterministic networks in the ensemble. Increasing members improves epistemic uncertainty estimation but raises computational cost (if enough parallel devices are available, training time is not affected).
hidden_layers
Widths of hidden layers. Defaults internally to [4, 4] if None.
Pre-sampling optimization
epochs / patience
Control how long the deterministic pre-training runs before sampling. epochs is the hard cap; patience triggers early stopping when the validation loss plateaus so the sampler starts from a high-likelihood region. When patience is None, training always runs for all epochs.
lr
Learning rate for the optimiser (optax.adamw by default) during the deterministic pre-training.
Sampling
warmup_steps / n_samples / n_thinning
Control the MCMC sampling stage. warmup_steps adjusts the step size, n_samples defines the number of retained posterior draws, and n_thinning specifies the interval between saved samples.
desired_energy_var_start / desired_energy_var_end / step_size_init
Configure the sampler's behaviour. The desired_energy_var_* parameters set the target variance of the energy during sampling, which is linearly annealed from start to end over the course of the warmup phase. The step_size_init parameter sets the initial step size for the dynamics integrator; this is adapted during warmup to reach the desired energy variance. For medium-sized BNNs a good default is desired_energy_var_start=0.5 and desired_energy_var_end=0.1, with the learning rate (or a slightly larger value) as step_size_init; a configuration sketch follows at the end of this section. For simpler models or highly over-parameterised settings (for example, if a 2x16 network already performs well on a small dataset, a 3x32 network would count as highly over-parameterised), decreasing the desired energy variance targets may be necessary to reach good performance. The desired energy variance is the most important hyperparameter to tune for sampler performance.
prior_family
Isotropic weight prior used for all ensemble members. Accepts string keys or bde.sampler.prior.PriorDist enums. Three families are supported: standardnormal (unit-variance Gaussian, and the default when unspecified), normal (Gaussian with configurable loc/scale), and laplace. Combine with prior_kwargs to adjust the distribution; for example {"scale": 0.1} narrows the prior and the initialisation around zero.
Note
The standardnormal shortcut always uses loc=0 and scale=1 and ignores prior_kwargs. Pick normal if you want a Gaussian prior with a different variance.
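Putting the sampling and prior settings together, a configuration along the lines of the recommendations above might look like the sketch below; every concrete value is an illustrative starting point, not a tuned result.

from bde import BdeRegressor

lr = 1e-3
reg = BdeRegressor(
    n_members=4,
    hidden_layers=[16, 16],
    lr=lr,
    warmup_steps=500,              # illustrative; warmup adapts the step size
    n_samples=100,                 # illustrative; retained posterior draws
    n_thinning=5,                  # illustrative; keep every 5th draw
    desired_energy_var_start=0.5,  # annealed linearly from start ...
    desired_energy_var_end=0.1,    # ... to end over the warmup phase
    step_size_init=lr,             # the learning rate, or slightly larger
    prior_family="normal",         # Gaussian prior with configurable loc/scale
    prior_kwargs={"scale": 0.1},   # narrows prior and initialisation around zero
)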
Sampler and builder internals#
After the deterministic training phase BdeRegressor and BdeClassifier
construct a bde.bde_builder.BdeBuilder instance. This helper manages the
ensemble members, coordinates parallel training across devices, and hands off to
bde.sampler utilities for warmup and sampling. Advanced users can interact
with these pieces directly:
estimator._bde references the builder after fit and exposes the deterministic members and training history.
estimator.positions_eT_ stores the weight samples with shape (E, T, ...).
Generally you should rely on the high-level estimator API, but the internals are accessible for custom diagnostics or research experiments.
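If you do reach for the internals, the pattern is plain attribute access on a fitted estimator, as in this brief sketch (reg and the scaled data are assumed from the earlier examples).

reg.fit(X_scaled, y_scaled)

builder = reg._bde               # BdeBuilder: deterministic members and training history
positions = reg.positions_eT_    # posterior weight samples, leading shape (E, T, ...)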
Where to next#
The Quick start page shows condensed scripts you can run end to end.
API Reference documents every public class and helper in the package.
Examples renders notebooks and plots that mirror the examples in the examples/ directory.