Precision control for semantic segmentation
This example illustrates how to control the precision of a semantic segmentation model using MAPIE.
We use SemanticSegmentationController
to calibrate a decision threshold that statistically guarantees
a target precision level on unseen data.
The dataset, model and utility functions are loaded from Hugging Face for simplicity and reproducibility.
import importlib.util
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from mapie.risk_control import SemanticSegmentationController
warnings.filterwarnings("ignore")
To keep this example self-contained, we load the dataset utilities and the segmentation LightningModule definition directly from a repository hosted on Hugging Face.
module_path = hf_hub_download(
    repo_id="mapie-library/rooftop_segmentation",
    filename="model_and_lightning_module.py",
    repo_type="dataset",
)
spec = importlib.util.spec_from_file_location("hf_module", module_path)
hf_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hf_module)
SegmentationLightningModule = hf_module.SegmentationLightningModule
RoofSegmentationDataset = hf_module.RoofSegmentationDataset
get_validation_transforms = hf_module.get_validation_transforms
Out:
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
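As a minimal, offline sketch of the same loading pattern, the snippet below writes a small throwaway module to a temporary directory and imports it from its file path. The file name `demo_module.py` and the function `double` are hypothetical stand-ins for the downloaded `model_and_lightning_module.py` and its contents.

```python
import importlib.util
import tempfile
from pathlib import Path

# Hypothetical module source standing in for the file fetched from the Hub.
SOURCE = "def double(x):\n    return 2 * x\n"

with tempfile.TemporaryDirectory() as tmp_dir:
    module_path = Path(tmp_dir) / "demo_module.py"
    module_path.write_text(SOURCE)

    # Build a module spec from the file path, create the module, then run it.
    spec = importlib.util.spec_from_file_location("demo_module", module_path)
    demo_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(demo_module)

# Attributes defined in the file are now available on the module object.
result = demo_module.double(21)
print(result)
```

Because `exec_module` runs the file's top-level code, this pattern should only be used with sources you trust.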
Load a pretrained segmentation model checkpoint from Hugging Face, and download the calibration and test data from the same repository.
model_ckpt = hf_hub_download(
    repo_id="mapie-library/rooftop_segmentation",
    filename="best_model-v1.ckpt",
    repo_type="dataset",
)
data_root = Path(
    snapshot_download(
        repo_id="mapie-library/rooftop_segmentation",
        repo_type="dataset",
        allow_patterns=["calib/**", "test/**"],
    )
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SegmentationLightningModule.load_from_checkpoint(model_ckpt)
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")
Out:
Fetching ... files: 1154it [00:34, 33.83it/s]
Model loaded successfully!
Next, two datasets are loaded from Hugging Face: a calibration set used to estimate risks and select an appropriate decision threshold, and a test set reserved for evaluating controlled predictions on unseen data.
CALIB_IMAGES_DIR = data_root / "calib" / "images"
CALIB_MASKS_DIR = data_root / "calib" / "masks"
TEST_IMAGES_DIR = data_root / "test" / "images"
TEST_MASKS_DIR = data_root / "test" / "masks"
calib_dataset = RoofSegmentationDataset(
    images_dir=CALIB_IMAGES_DIR,
    masks_dir=CALIB_MASKS_DIR,
    transform=get_validation_transforms(
        image_size=(256, 256)
    ),  # resize images to reduce memory usage
)
calib_loader = torch.utils.data.DataLoader(calib_dataset, batch_size=16)
test_dataset = RoofSegmentationDataset(
    images_dir=TEST_IMAGES_DIR,
    masks_dir=TEST_MASKS_DIR,
    transform=get_validation_transforms(
        image_size=(256, 256)
    ),  # resize images to reduce memory usage
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
print(f"Calibration set size: {len(calib_dataset)}")
print(f"Test set size: {len(test_dataset)}")
Out:
Dataset initialized with 289 image-mask pairs
Dataset initialized with 288 image-mask pairs
Calibration set size: 289
Test set size: 288
A SemanticSegmentationController is instantiated
to control the precision risk (1 - precision) and automatically select a threshold
that meets the target precision level with high confidence.
TARGET_PRECISION = 0.7
CONFIDENCE_LEVEL = 0.9
precision_controller = SemanticSegmentationController(
    predict_function=model,
    risk="precision",
    target_level=TARGET_PRECISION,
    confidence_level=CONFIDENCE_LEVEL,
)
print(f"Target precision level: {TARGET_PRECISION}")
Out:
Target precision level: 0.7
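To make the controlled quantity concrete, here is a small NumPy sketch of the precision risk, 1 - precision, for a single image at a given threshold. The helper `precision_risk` and the toy arrays are illustrative assumptions, not part of MAPIE's API. In particular, returning zero risk when no pixel is predicted positive is a convention chosen here.

```python
import numpy as np


def precision_risk(probs: np.ndarray, mask: np.ndarray, threshold: float) -> float:
    """Return 1 - precision of the thresholded prediction (hypothetical helper)."""
    pred = probs >= threshold
    n_pred = pred.sum()
    if n_pred == 0:
        # Assumed convention: no predicted positives means no false positives.
        return 0.0
    true_positives = np.logical_and(pred, mask.astype(bool)).sum()
    return float(1.0 - true_positives / n_pred)


# Toy 2x3 probability map and ground-truth mask.
probs = np.array([[0.9, 0.8, 0.4], [0.6, 0.2, 0.1]])
mask = np.array([[1, 0, 1], [1, 0, 0]])

risk_low = precision_risk(probs, mask, threshold=0.5)    # keeps 0.9, 0.8, 0.6
risk_high = precision_risk(probs, mask, threshold=0.85)  # keeps 0.9 only
print(risk_low, risk_high)
```

Raising the threshold removes the lower-scoring false positive in this toy case, which is the effect a threshold-based controller relies on.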
During calibration, the controller evaluates the precision risk over a range of thresholds on the calibration dataset in order to identify an optimal decision threshold.
for i, sample in enumerate(calib_loader):
    image, mask = sample["image"], sample["mask"]
    image = image.to(DEVICE)
    mask = mask.cpu().numpy()
    # Filter images that contain masks
    has_mask = mask.sum(axis=(1, 2)) > 0
    image = image[has_mask]
    mask = mask[has_mask]
    if len(image) > 0:
        with torch.no_grad():
            precision_controller.compute_risks(image, mask)
# Compute the best threshold
precision_controller.compute_best_predict_param()
print("Controller calibrated successfully!")
print(f"Optimal threshold found: {precision_controller.best_predict_param[0]:.4f}")
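The calibration step can be pictured with the following simplified sketch. It is not MAPIE's implementation: the synthetic data, the Hoeffding p-value, the Bonferroni correction, and the choice of the smallest valid threshold are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
ALPHA = 0.3  # tolerated precision risk, i.e. 1 - target precision of 0.7
DELTA = 0.1  # 1 - confidence level of 0.9

# Synthetic calibration set: 200 images of 16x16 pixels (hypothetical data).
masks = rng.random((200, 16, 16)) < 0.3
probs = np.where(masks, rng.beta(4, 2, masks.shape), rng.beta(2, 4, masks.shape))


def mean_precision_risk(threshold: float) -> float:
    """Average over images of 1 - precision at the given threshold."""
    preds = probs >= threshold
    tp = (preds & masks).sum(axis=(1, 2))
    n_pred = preds.sum(axis=(1, 2))
    # Assumed convention: images with no predicted positive have zero risk.
    precision = np.where(n_pred > 0, tp / np.maximum(n_pred, 1), 1.0)
    return float((1.0 - precision).mean())


thresholds = np.linspace(0.05, 0.95, 19)
risks = np.array([mean_precision_risk(t) for t in thresholds])

# Hoeffding p-value for the null hypothesis "true risk > ALPHA".
n_images = masks.shape[0]
p_values = np.exp(-2 * n_images * np.clip(ALPHA - risks, 0, None) ** 2)

# Bonferroni correction over all tested thresholds.
valid = thresholds[p_values <= DELTA / len(thresholds)]
best_threshold = valid.min()  # keeps as many predicted pixels as possible
print(f"Valid thresholds: {valid.round(2)}")
print(f"Selected threshold: {best_threshold:.2f}")
```

In this sketch a threshold is accepted only when its empirical risk is far enough below `ALPHA` for the statistical test to reject, so the selected threshold is more conservative than the one that merely reaches the target on the calibration data.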
Controlled predictions are visually inspected on a few test images to illustrate the effect of MAPIE thresholding compared to raw model outputs.
def denormalize_image(tensor_image: torch.Tensor) -> np.ndarray:
    """Denormalize image tensor for visualization."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = tensor_image.cpu().numpy().transpose(1, 2, 0)
    image = std * image + mean
    image = np.clip(image, 0, 1)
    return image
# Select random test images
NUM_EXAMPLES = 4
np.random.seed(0)
# Get indices of images with masks
indices_with_masks = []
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    if sample["mask"].sum() > 0:
        indices_with_masks.append(idx)
random_indices = np.random.choice(indices_with_masks, NUM_EXAMPLES, replace=False)
fig, axes = plt.subplots(2, NUM_EXAMPLES, figsize=(4 * NUM_EXAMPLES, 10))
for col, idx in enumerate(random_indices):
    sample = test_dataset[idx]
    image = sample["image"].unsqueeze(0).to(DEVICE)
    mask = sample["mask"].cpu().numpy()
    with torch.no_grad():
        # Get MAPIE prediction
        mapie_pred = precision_controller.predict(image)[0]
    # Denormalize image
    img_display = denormalize_image(sample["image"])
    # Plot original image (top row)
    axes[0, col].imshow(img_display)
    axes[0, col].set_title("Original Image")
    axes[0, col].axis("off")
    # Plot MAPIE prediction with correct pixels in white and false positives in red (bottom row)
    pred_visualization = np.zeros((*mapie_pred[0].shape, 3))
    true_positives = mask * mapie_pred[0]
    pred_visualization[true_positives > 0] = [1, 1, 1]
    false_positives = (1 - mask) * mapie_pred[0]
    pred_visualization[false_positives > 0] = [1, 0, 0]
    axes[1, col].imshow(pred_visualization)
    axes[1, col].set_title(
        f"MAPIE Prediction (threshold={precision_controller.best_predict_param[0]:.2f})\n"
        "White: Correct | Red: False Positives"
    )
    axes[1, col].axis("off")
plt.tight_layout()
plt.show()
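The colour coding used in the bottom row can be reproduced on a toy example. The 2x2 mask and prediction below are made-up values, chosen so that each pixel falls in a different case.

```python
import numpy as np

# Toy 2x2 ground-truth mask and binary prediction (hypothetical values).
mask = np.array([[1, 0], [1, 0]])
pred = np.array([[1, 1], [0, 0]])

# Start from a black RGB canvas with the same spatial shape as the prediction.
overlay = np.zeros((*pred.shape, 3))
overlay[(mask * pred) > 0] = [1, 1, 1]        # true positives in white
overlay[((1 - mask) * pred) > 0] = [1, 0, 0]  # false positives in red

print(overlay[0, 0], overlay[0, 1], overlay[1, 0])
```

With this encoding, false negatives stay black and are indistinguishable from true background, which fits a plot focused on precision rather than recall.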

The controller is finally evaluated on the test set by computing the achieved precision on each image to verify that the target precision level is satisfied on unseen data.
precisions_list = []
for i, sample in enumerate(test_loader):
    image, mask = sample["image"], sample["mask"]
    image = image.to(DEVICE)
    mask = mask.cpu().numpy()
    # Filter images with masks
    has_mask = mask.sum(axis=(1, 2)) > 0
    image = image[has_mask]
    mask = mask[has_mask]
    if len(image) > 0:
        with torch.no_grad():
            pred = precision_controller.predict(image)
        # Compute precision for each image
        for j in range(len(image)):
            tp = (mask[j] * pred[j]).sum()
            fp = ((1 - mask[j]) * pred[j]).sum()
            precision = tp / (tp + fp + 1e-8)
            precisions_list.append(precision)
precisions_array = np.array(precisions_list)
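The per-image precision formula can be checked on a toy batch. The sketch below is a vectorized variant of the inner loop, using made-up 2x2 masks and predictions. It assumes masks and predictions share the shape (batch, height, width).

```python
import numpy as np

EPS = 1e-8  # avoids division by zero when nothing is predicted

# Toy batch of two 2x2 masks and predictions (hypothetical values).
masks = np.array([[[1, 1], [0, 0]], [[1, 0], [0, 0]]])
preds = np.array([[[1, 1], [1, 0]], [[0, 0], [0, 0]]])

# Sum over the spatial axes to get one count per image.
tp = (masks * preds).sum(axis=(1, 2))
fp = ((1 - masks) * preds).sum(axis=(1, 2))
precisions = tp / (tp + fp + EPS)
print(precisions)
```

Note that the second image, where no pixel is predicted positive, gets a precision of 0 under this formula, so such images pull the mean precision down.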
Finally, the distribution of precision values over the test set is plotted to summarize the controlled performance.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(precisions_array, bins=30, alpha=0.7, color="steelblue", edgecolor="black")
ax.axvline(
    TARGET_PRECISION,
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Target Precision ({TARGET_PRECISION})",
)
ax.axvline(
    precisions_array.mean(),
    color="orange",
    linestyle="--",
    linewidth=2,
    label=f"Mean Precision ({precisions_array.mean():.3f})",
)
ax.set_xlabel("Precision", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title("Distribution of Precision on Test Set", fontsize=14, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

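The histogram shows that most test images achieve or exceed the target precision level, illustrating the effectiveness of MAPIE’s risk control for semantic segmentation tasks. Note that the guarantee concerns the expected precision rather than every individual image, so a few images may still fall below the target.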
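Two numbers summarize such a histogram: the share of images at or above the target, and the mean precision. The sketch below computes both on made-up values that stand in for `precisions_array`.

```python
import numpy as np

TARGET_PRECISION = 0.7
# Hypothetical per-image precisions, standing in for `precisions_array`.
precisions = np.array([0.95, 0.88, 0.72, 0.65, 0.91, 0.40, 0.83, 0.78])

share_above = float((precisions >= TARGET_PRECISION).mean())
mean_precision = float(precisions.mean())
print(f"Share of images at or above target: {share_above:.2f}")
print(f"Mean precision: {mean_precision:.3f}")
```

In this toy case the mean exceeds the target even though two of the eight images fall below it.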
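To assess how the mean precision varies with the test sample, it is bootstrapped over resamplings of the test set (images drawn with replacement). The (1 - CONFIDENCE_LEVEL)-quantile of the bootstrap means is then compared with the target precision.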
N_BOOTSTRAP = 2000
BOOTSTRAP_SEED = 123
rng = np.random.default_rng(BOOTSTRAP_SEED)
bootstrap_means = np.empty(N_BOOTSTRAP, dtype=float)
n = precisions_array.size
for b in range(N_BOOTSTRAP):
    bootstrap_sample = rng.choice(precisions_array, size=n, replace=True)
    bootstrap_means[b] = bootstrap_sample.mean()
delta = round(1 - CONFIDENCE_LEVEL, 2)
quantile_confidence = np.quantile(bootstrap_means, delta)
print(f"Bootstrap {delta}-th quantile: {quantile_confidence:.4f}")
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(
    bootstrap_means,
    bins=40,
    alpha=0.7,
    color="slateblue",
    edgecolor="black",
)
ax.axvline(
    TARGET_PRECISION,
    color="orange",
    linestyle="--",
    linewidth=2,
    label=f"Target precision ({TARGET_PRECISION:.2f})",
)
ax.axvline(
    quantile_confidence,
    color="green",
    linestyle="--",
    linewidth=2,
    label=f"Bootstrap {delta}-th quantile ({quantile_confidence:.3f})",
)
ax.set_xlabel("Bootstrap mean precision", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title(
    "Bootstrap distribution of mean precision (test set resampling)",
    fontsize=14,
    fontweight="bold",
)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
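The resampling loop can also be written without an explicit Python loop. The following sketch is one possible vectorized variant. It draws all bootstrap indices at once and uses synthetic precision values (a Beta sample chosen arbitrarily) in place of `precisions_array`.

```python
import numpy as np

rng = np.random.default_rng(123)
# Hypothetical per-image precisions, standing in for `precisions_array`.
precisions = rng.beta(8, 2, size=300)

N_BOOTSTRAP = 2000
DELTA = 0.1

# Draw all resampling indices at once: one row per bootstrap replicate.
indices = rng.integers(0, precisions.size, size=(N_BOOTSTRAP, precisions.size))
bootstrap_means = precisions[indices].mean(axis=1)

# The DELTA-quantile is an approximate one-sided lower bound on the mean.
lower_bound = np.quantile(bootstrap_means, DELTA)
print(f"Sample mean: {precisions.mean():.4f}")
print(f"Bootstrap {DELTA}-quantile: {lower_bound:.4f}")
```

The index matrix holds `N_BOOTSTRAP * n` integers, so for very large test sets the explicit loop may be the more memory-friendly choice.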

Total running time of the script: (1 minute 23.477 seconds)