Note
Click here to download the full example code
Recall control for semantic segmentation
This example illustrates how to control the recall of a semantic segmentation model using MAPIE.
We use SemanticSegmentationController
to calibrate a decision threshold that statistically guarantees
a target recall level on unseen data.
The dataset, model and utility functions are loaded from Hugging Face for simplicity and reproducibility.
import importlib.util
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from mapie.risk_control import SemanticSegmentationController
# Silence every warning category for the rest of the script
# (presumably to keep the rendered example output uncluttered).
warnings.filterwarnings("ignore")
To keep this example self-contained, we load the dataset utilities and the segmentation LightningModule definition directly from a repository hosted on Hugging Face.
# Download the helper module (dataset utilities and LightningModule definition)
# from the Hugging Face dataset repository.
# NOTE(security): the downloaded file is executed below. Only do this with a
# repository you trust, and consider pinning a `revision` in `hf_hub_download`
# so that the executed code cannot change between runs.
module_path = hf_hub_download(
    repo_id="mapie-library/rooftop_segmentation",
    filename="model_and_lightning_module.py",
    repo_type="dataset",
)

# Import the downloaded file as a module named "hf_module".
spec = importlib.util.spec_from_file_location("hf_module", module_path)
if spec is None or spec.loader is None:
    # Fail with an explicit message instead of an AttributeError on None.
    raise ImportError(f"Cannot build an import spec for {module_path!r}")
hf_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hf_module)

# Re-export the names used by the rest of the example.
SegmentationLightningModule = hf_module.SegmentationLightningModule
RoofSegmentationDataset = hf_module.RoofSegmentationDataset
get_validation_transforms = hf_module.get_validation_transforms
Load a pretrained segmentation model checkpoint from Hugging Face.
# Fetch the pretrained checkpoint file from the Hugging Face repository.
model_ckpt = hf_hub_download(
    filename="best_model-v1.ckpt",
    repo_id="mapie-library/rooftop_segmentation",
    repo_type="dataset",
)

# Download only the calibration and test folders of the dataset snapshot.
snapshot_dir = snapshot_download(
    repo_id="mapie-library/rooftop_segmentation",
    repo_type="dataset",
    allow_patterns=["calib/**", "test/**"],
)
data_root = Path(snapshot_dir)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Restore the model from the checkpoint, move it to the selected device and
# switch it to evaluation mode.
model = SegmentationLightningModule.load_from_checkpoint(model_ckpt)
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")
Out:
Fetching ... files: 0it [00:00, ?it/s]
Fetching ... files: 1154it [00:00, 399358.65it/s]
Model loaded successfully!
Next, two datasets are loaded from Hugging Face: a calibration set used to estimate risks and select an appropriate decision threshold, and a test set reserved for evaluating controlled predictions on unseen data.
# Locations of the image and mask folders for each split.
calib_dir = data_root / "calib"
test_dir = data_root / "test"
CALIB_IMAGES_DIR = calib_dir / "images"
CALIB_MASKS_DIR = calib_dir / "masks"
TEST_IMAGES_DIR = test_dir / "images"
TEST_MASKS_DIR = test_dir / "masks"

# Calibration split; images are reshaped to 256x256 to reduce memory usage.
calib_transform = get_validation_transforms(image_size=(256, 256))
calib_dataset = RoofSegmentationDataset(
    images_dir=CALIB_IMAGES_DIR,
    masks_dir=CALIB_MASKS_DIR,
    transform=calib_transform,
)
calib_loader = torch.utils.data.DataLoader(calib_dataset, batch_size=16)

# Test split, built with the same transform settings.
test_transform = get_validation_transforms(image_size=(256, 256))
test_dataset = RoofSegmentationDataset(
    images_dir=TEST_IMAGES_DIR,
    masks_dir=TEST_MASKS_DIR,
    transform=test_transform,
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16)

print(f"Calibration set size: {len(calib_dataset)}")
print(f"Test set size: {len(test_dataset)}")
Out:
Dataset initialized with 289 image-mask pairs
Dataset initialized with 288 image-mask pairs
Calibration set size: 289
Test set size: 288
A SemanticSegmentationController is instantiated
to control the recall risk (1 - recall) and automatically select a threshold
that meets the target recall level with high confidence.
# Recall level the controlled predictions should reach.
TARGET_RECALL = 0.9

# NOTE(review): confidence_level=None is passed explicitly; its exact effect is
# defined by MAPIE and is not visible here — confirm against the MAPIE docs.
recall_controller = SemanticSegmentationController(
    predict_function=model,
    target_level=TARGET_RECALL,
    risk="recall",
    confidence_level=None,
)
print(f"Target recall level: {TARGET_RECALL}")
Out:
During calibration, the controller evaluates the recall risk over a range of thresholds on the calibration dataset in order to identify an optimal decision threshold.
# Feed the calibration set to the controller batch by batch.
for batch in calib_loader:
    images = batch["image"].to(DEVICE)
    masks = batch["mask"].cpu().numpy()

    # Keep only the samples whose mask sums to a positive value.
    non_empty = masks.sum(axis=(1, 2)) > 0
    images, masks = images[non_empty], masks[non_empty]
    if len(images) == 0:
        continue

    with torch.no_grad():
        recall_controller.compute_risks(images, masks)

# Select the threshold from the risks accumulated above.
recall_controller.compute_best_predict_param()
print("Controller calibrated successfully!")
print(f"Optimal threshold found: {recall_controller.best_predict_param[0]:.4f}")
Out:
Controlled predictions are visually inspected on a few test images to illustrate the effect of MAPIE thresholding compared to raw model outputs.
def denormalize_image(tensor_image: torch.Tensor) -> np.ndarray:
    """Undo per-channel normalization so the image can be displayed.

    Parameters
    ----------
    tensor_image : torch.Tensor
        Channel-first image tensor with three channels, i.e. shape (3, H, W).

    Returns
    -------
    np.ndarray
        Channel-last array of shape (H, W, 3) with values clipped to [0, 1].
    """
    # Per-channel statistics applied in reverse: image * std + mean.
    # NOTE(review): these look like the usual ImageNet statistics — confirm
    # they match the normalization used by the dataset transforms.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # detach() makes the conversion work for tensors that require grad,
    # for which Tensor.numpy() would otherwise raise a RuntimeError.
    image = tensor_image.detach().cpu().numpy().transpose(1, 2, 0)
    image = std * image + mean
    return np.clip(image, 0, 1)
# Number of test images shown in the figure.
NUM_EXAMPLES = 4
# Seed NumPy's global random state so the same images are drawn on every run.
np.random.seed(0)

# Candidate indices: test samples whose mask sums to a positive value.
indices_with_masks = [
    idx for idx in range(len(test_dataset)) if test_dataset[idx]["mask"].sum() > 0
]

# Draw NUM_EXAMPLES distinct candidates.
random_indices = np.random.choice(
    indices_with_masks, size=NUM_EXAMPLES, replace=False
)
# One column per example: input image on top, controlled prediction below.
fig, axes = plt.subplots(
    nrows=2, ncols=NUM_EXAMPLES, figsize=(4 * NUM_EXAMPLES, 10)
)
for (image_ax, pred_ax), idx in zip(axes.T, random_indices):
    sample = test_dataset[idx]
    batch = sample["image"].unsqueeze(0).to(DEVICE)
    truth = sample["mask"].cpu().numpy()

    # MAPIE prediction for the single-image batch.
    with torch.no_grad():
        mapie_pred = recall_controller.predict(batch)[0]
    pred_mask = mapie_pred[0]

    # Top row: denormalized input image.
    image_ax.imshow(denormalize_image(sample["image"]))
    image_ax.set_title("Original Image")
    image_ax.axis("off")

    # Bottom row: RGB overlay, white where a positive pixel is predicted,
    # red where a positive pixel is missed (false negative).
    overlay = np.zeros((*pred_mask.shape, 3))
    overlay[truth * pred_mask > 0] = [1, 1, 1]
    overlay[truth * (1 - pred_mask) > 0] = [1, 0, 0]
    pred_ax.imshow(overlay)
    pred_ax.set_title(
        f"MAPIE Prediction (threshold={recall_controller.best_predict_param[0]:.2f})\n"
        "White: Correct | Red: False Negatives"
    )
    pred_ax.axis("off")

plt.tight_layout()
plt.show()

The controller is finally evaluated on the test set by computing the achieved recall on each image to verify that the target recall level is satisfied on unseen data.
# Per-image recall of the controlled predictions over the test set.
recalls_list = []
for batch in test_loader:
    images = batch["image"].to(DEVICE)
    masks = batch["mask"].cpu().numpy()

    # Keep only the samples whose mask sums to a positive value.
    non_empty = masks.sum(axis=(1, 2)) > 0
    images, masks = images[non_empty], masks[non_empty]
    if len(images) == 0:
        continue

    with torch.no_grad():
        preds = recall_controller.predict(images)

    # recall = TP / (TP + FN); the small constant guards the division.
    for truth, pred in zip(masks, preds):
        tp = (truth * pred).sum()
        fn = (truth * (1 - pred)).sum()
        recalls_list.append(tp / (tp + fn + 1e-8))

recalls_array = np.array(recalls_list)
Finally, the distribution of recall values over the test set is plotted to summarize the controlled performance.
# Histogram of per-image recall with the target and the mean marked.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(recalls_array, bins=30, alpha=0.7, color="steelblue", edgecolor="black")

# Both reference lines share the same dashed style.
reference_line = {"linestyle": "--", "linewidth": 2}
ax.axvline(
    TARGET_RECALL,
    color="red",
    label=f"Target Recall ({TARGET_RECALL})",
    **reference_line,
)
ax.axvline(
    recalls_array.mean(),
    color="orange",
    label=f"Mean Recall ({recalls_array.mean():.3f})",
    **reference_line,
)

ax.set_xlabel("Recall", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title("Distribution of Recall on Test Set", fontsize=14, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

The histogram shows that most test images achieve or exceed the target recall level, illustrating the effectiveness of MAPIE’s risk control for semantic segmentation tasks.
To assess the variability of the mean recall, we bootstrap it by resampling the test images with replacement and recomputing the mean recall on each resample.
# Number of bootstrap resamples and the seed that makes them reproducible.
N_BOOTSTRAP = 2000
BOOTSTRAP_SEED = 123
rng = np.random.default_rng(BOOTSTRAP_SEED)

# Each resample draws as many recall values as the test set provides,
# with replacement, and records their mean.
n = recalls_array.size
bootstrap_means = np.array(
    [
        rng.choice(recalls_array, size=n, replace=True).mean()
        for _ in range(N_BOOTSTRAP)
    ],
    dtype=float,
)
# Histogram of the bootstrap means with their average and the target marked.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(bootstrap_means, bins=40, alpha=0.7, color="slateblue", edgecolor="black")

# Both reference lines share the same dashed style.
reference_line = {"linestyle": "--", "linewidth": 2}
ax.axvline(
    bootstrap_means.mean(),
    color="green",
    label=f"Bootstrap mean ({bootstrap_means.mean():.3f})",
    **reference_line,
)
ax.axvline(
    TARGET_RECALL,
    color="orange",
    label=f"Target recall ({TARGET_RECALL:.2f})",
    **reference_line,
)

ax.set_xlabel("Bootstrap mean recall", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title(
    "Bootstrap distribution of mean recall (test set resampling)",
    fontsize=14,
    fontweight="bold",
)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Total running time of the script: ( 0 minutes 42.104 seconds)
Download Python source code: plot_semantic_segmentation_recall_control.py
Download Jupyter notebook: plot_semantic_segmentation_recall_control.ipynb