Precision control for semantic segmentation

This example illustrates how to control the precision of a semantic segmentation model using MAPIE.

We use SemanticSegmentationController to calibrate a decision threshold that statistically guarantees a target precision level on unseen data.

The dataset, model and utility functions are loaded from Hugging Face for simplicity and reproducibility.

import importlib.util
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from mapie.risk_control import SemanticSegmentationController

warnings.filterwarnings("ignore")

To keep this example self-contained, we load the dataset utilities and the segmentation LightningModule definition directly from a repository hosted on Hugging Face.

module_path = hf_hub_download(
    repo_id="mapie-library/rooftop_segmentation",
    filename="model_and_lightning_module.py",
    repo_type="dataset",
)
spec = importlib.util.spec_from_file_location("hf_module", module_path)
hf_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hf_module)

SegmentationLightningModule = hf_module.SegmentationLightningModule
RoofSegmentationDataset = hf_module.RoofSegmentationDataset
get_validation_transforms = hf_module.get_validation_transforms

Out:

Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
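
The dynamic import above relies only on the standard library. As a minimal, self-contained sketch of the same pattern, the snippet below writes a small throwaway module to a temporary file and loads it by path. The helper `load_module_from_path`, the module name `demo_module` and its `greet` function are hypothetical; they stand in for the file downloaded from Hugging Face.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_module_from_path(name: str, path: str):
    """Load a Python source file as a module without modifying sys.path."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file's top-level code
    return module


# Hypothetical stand-in for the file fetched with hf_hub_download.
with tempfile.TemporaryDirectory() as tmp_dir:
    module_path = Path(tmp_dir) / "demo_module.py"
    module_path.write_text("def greet(name):\n    return f'hello {name}'\n")
    demo_module = load_module_from_path("demo_module", str(module_path))

print(demo_module.greet("mapie"))  # prints "hello mapie"
```

Because `exec_module` runs the file's top-level code, this pattern is best reserved for repositories you trust.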

Download a pretrained segmentation model checkpoint together with the calibration and test data from Hugging Face, then load the model on the available device.

model_ckpt = hf_hub_download(
    repo_id="mapie-library/rooftop_segmentation",
    filename="best_model-v1.ckpt",
    repo_type="dataset",
)

data_root = Path(
    snapshot_download(
        repo_id="mapie-library/rooftop_segmentation",
        repo_type="dataset",
        allow_patterns=["calib/**", "test/**"],
    )
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = SegmentationLightningModule.load_from_checkpoint(model_ckpt)
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")

Out:

Fetching ... files: 1154it [00:34, 33.83it/s]
Model loaded successfully!

Next, two datasets are built from the downloaded snapshot: a calibration set used to estimate risks and select an appropriate decision threshold, and a test set reserved for evaluating controlled predictions on unseen data.

CALIB_IMAGES_DIR = data_root / "calib" / "images"
CALIB_MASKS_DIR = data_root / "calib" / "masks"
TEST_IMAGES_DIR = data_root / "test" / "images"
TEST_MASKS_DIR = data_root / "test" / "masks"

calib_dataset = RoofSegmentationDataset(
    images_dir=CALIB_IMAGES_DIR,
    masks_dir=CALIB_MASKS_DIR,
    transform=get_validation_transforms(
        image_size=(256, 256)
    ),  # resize images to reduce memory usage
)
calib_loader = torch.utils.data.DataLoader(calib_dataset, batch_size=16)

test_dataset = RoofSegmentationDataset(
    images_dir=TEST_IMAGES_DIR,
    masks_dir=TEST_MASKS_DIR,
    transform=get_validation_transforms(
        image_size=(256, 256)
    ),  # resize images to reduce memory usage
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16)

print(f"Calibration set size: {len(calib_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Out:

Dataset initialized with 289 image-mask pairs
Dataset initialized with 288 image-mask pairs
Calibration set size: 289
Test set size: 288

A SemanticSegmentationController is instantiated to control the precision risk (1 - precision) and automatically select a threshold that meets the target precision level with high confidence. Here the target precision is 0.7 and the confidence level is 0.9.

TARGET_PRECISION = 0.7
CONFIDENCE_LEVEL = 0.9
precision_controller = SemanticSegmentationController(
    predict_function=model,
    risk="precision",
    target_level=TARGET_PRECISION,
    confidence_level=CONFIDENCE_LEVEL,
)

print(f"Target precision level: {TARGET_PRECISION}")

Out:

Target precision level: 0.7
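
To make the notion of precision risk concrete, here is a minimal sketch of how precision, and hence the risk 1 - precision, changes with the decision threshold on a single toy image. It uses plain Python lists to stay dependency-free. The function `precision_at_threshold`, the toy values and the `>=` thresholding convention are assumptions made for illustration, not MAPIE's implementation.

```python
def precision_at_threshold(probs, mask, threshold):
    """Pixel-wise precision of the binary prediction `probs >= threshold`.

    `probs` and `mask` are flat lists over the pixels of one image.
    Returns None when no pixel is predicted positive (precision undefined).
    """
    predicted = [p >= threshold for p in probs]
    tp = sum(1 for pred, m in zip(predicted, mask) if pred and m == 1)
    fp = sum(1 for pred, m in zip(predicted, mask) if pred and m == 0)
    if tp + fp == 0:
        return None
    return tp / (tp + fp)


# Toy 8-pixel image: predicted roof probabilities and ground-truth mask.
probs = [0.95, 0.90, 0.80, 0.70, 0.60, 0.40, 0.30, 0.10]
mask = [1, 1, 1, 0, 1, 0, 0, 0]

for threshold in (0.2, 0.5, 0.75):
    precision = precision_at_threshold(probs, mask, threshold)
    print(f"threshold={threshold:.2f}  precision={precision:.3f}  risk={1 - precision:.3f}")
# threshold=0.20  precision=0.571  risk=0.429
# threshold=0.50  precision=0.800  risk=0.200
# threshold=0.75  precision=1.000  risk=0.000
```

In this toy case a higher threshold removes false positives and lowers the risk, at the cost of predicting fewer pixels. Precision is not guaranteed to increase with the threshold in general, which is one reason the threshold is calibrated on data rather than set by hand.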

During calibration, the controller evaluates the precision risk over a range of thresholds on the calibration dataset in order to identify an optimal decision threshold. Images whose ground-truth mask is empty are skipped.

for sample in calib_loader:
    image, mask = sample["image"], sample["mask"]
    image = image.to(DEVICE)
    mask = mask.cpu().numpy()

    # Keep only images whose ground-truth mask is non-empty
    has_mask = mask.sum(axis=(1, 2)) > 0
    image = image[has_mask]
    mask = mask[has_mask]

    if len(image) > 0:
        with torch.no_grad():
            precision_controller.compute_risks(image, mask)

# Compute the best threshold
precision_controller.compute_best_predict_param()
print("Controller calibrated successfully!")
print(f"Optimal threshold found: {precision_controller.best_predict_param[0]:.4f}")

Out:

Controller calibrated successfully!
Optimal threshold found: 0.8700
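
The calibration step above can be pictured with a simplified sketch: for each candidate threshold, average the per-image risks on the calibration set, add a statistical margin, and keep the thresholds whose upper bound stays below the tolerated risk. The sketch below uses a Hoeffding bound with a Bonferroni correction across candidates, which is a deliberately simple substitute and not necessarily the procedure MAPIE applies. The function `select_threshold` and the toy risks are hypothetical.

```python
import math


def select_threshold(risks_per_threshold, target_precision, confidence_level):
    """Return the smallest threshold whose risk is certified, or None.

    `risks_per_threshold` maps each candidate threshold to the list of
    per-image risks (1 - precision, each in [0, 1]) on calibration data.
    """
    alpha = 1 - target_precision  # maximum tolerated risk
    delta = 1 - confidence_level  # allowed failure probability
    n_candidates = len(risks_per_threshold)
    certified = []
    for threshold, risks in risks_per_threshold.items():
        n = len(risks)
        mean_risk = sum(risks) / n
        # Hoeffding margin, with delta split across candidates (Bonferroni).
        margin = math.sqrt(math.log(n_candidates / delta) / (2 * n))
        if mean_risk + margin <= alpha:
            certified.append(threshold)
    return min(certified) if certified else None


# Toy calibration risks for 200 images at three candidate thresholds.
toy_risks = {
    0.5: [0.2, 0.5] * 100,  # mean risk 0.35
    0.7: [0.1, 0.4] * 100,  # mean risk 0.25
    0.9: [0.0, 0.2] * 100,  # mean risk 0.10
}

best = select_threshold(toy_risks, target_precision=0.7, confidence_level=0.9)
print(best)  # prints 0.9
```

With 200 images and three candidates the margin is about 0.092, so the threshold 0.7 is rejected even though its empirical risk (0.25) is below the tolerated 0.3: the margin is the price of the confidence guarantee. Returning the smallest certified threshold is a design choice of this sketch, made to keep as many predicted pixels as possible.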

Controlled predictions are visually inspected on a few test images to illustrate the effect of MAPIE thresholding compared to raw model outputs.

def denormalize_image(tensor_image: torch.Tensor) -> np.ndarray:
    """Denormalize image tensor for visualization."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    image = tensor_image.cpu().numpy().transpose(1, 2, 0)
    image = std * image + mean
    image = np.clip(image, 0, 1)

    return image


# Select random test images
NUM_EXAMPLES = 4
np.random.seed(0)

# Get indices of images with masks
indices_with_masks = []
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    if sample["mask"].sum() > 0:
        indices_with_masks.append(idx)

random_indices = np.random.choice(indices_with_masks, NUM_EXAMPLES, replace=False)

fig, axes = plt.subplots(2, NUM_EXAMPLES, figsize=(4 * NUM_EXAMPLES, 10))

for col, idx in enumerate(random_indices):
    sample = test_dataset[idx]
    image = sample["image"].unsqueeze(0).to(DEVICE)
    mask = sample["mask"].cpu().numpy()

    with torch.no_grad():
        # Get MAPIE prediction
        mapie_pred = precision_controller.predict(image)[0]

    # Denormalize image
    img_display = denormalize_image(sample["image"])

    # Plot original image (top row)
    axes[0, col].imshow(img_display)
    axes[0, col].set_title("Original Image")
    axes[0, col].axis("off")

    # Plot MAPIE prediction with correct pixels in white and false positives in red (bottom row)
    pred_visualization = np.zeros((*mapie_pred[0].shape, 3))
    true_positives = mask * mapie_pred[0]
    pred_visualization[true_positives > 0] = [1, 1, 1]
    false_positives = (1 - mask) * mapie_pred[0]
    pred_visualization[false_positives > 0] = [1, 0, 0]

    axes[1, col].imshow(pred_visualization)
    axes[1, col].set_title(
        f"MAPIE Prediction (threshold={precision_controller.best_predict_param[0]:.2f})\n"
        "White: Correct | Red: False Positives"
    )
    axes[1, col].axis("off")

plt.tight_layout()
plt.show()

[Figure: four test images (top row, "Original Image") and the corresponding MAPIE predictions at threshold 0.87 (bottom row; white: correct, red: false positives).]

The controller is finally evaluated on the test set by computing the achieved precision on each image to verify that the target precision level is satisfied on unseen data.

precisions_list = []

for sample in test_loader:
    image, mask = sample["image"], sample["mask"]
    image = image.to(DEVICE)
    mask = mask.cpu().numpy()

    # Keep only images whose ground-truth mask is non-empty
    has_mask = mask.sum(axis=(1, 2)) > 0
    image = image[has_mask]
    mask = mask[has_mask]

    if len(image) > 0:
        with torch.no_grad():
            pred = precision_controller.predict(image)

            # Compute precision for each image
            for j in range(len(image)):
                tp = (mask[j] * pred[j]).sum()
                fp = ((1 - mask[j]) * pred[j]).sum()
                precision = tp / (tp + fp + 1e-8)
                precisions_list.append(precision)

precisions_array = np.array(precisions_list)
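
One detail of the loop above deserves attention: with the small epsilon in the denominator, an image for which no pixel is predicted positive gets a precision of exactly 0, which pulls the mean down. The sketch below contrasts this convention with an alternative that leaves such images out. The helper names and the pixel counts are hypothetical and only serve the illustration.

```python
def precision_with_epsilon(tp: float, fp: float, eps: float = 1e-8) -> float:
    """Precision as computed in the loop above: eps avoids division by zero."""
    return tp / (tp + fp + eps)


def precision_or_none(tp: float, fp: float):
    """Alternative convention: None when nothing is predicted positive."""
    if tp + fp == 0:
        return None
    return tp / (tp + fp)


# (tp, fp) pixel counts for three hypothetical images;
# the last one has no positive prediction at all.
counts = [(90, 10), (40, 10), (0, 0)]

with_eps = [precision_with_epsilon(tp, fp) for tp, fp in counts]
defined_only = [
    p for p in (precision_or_none(tp, fp) for tp, fp in counts) if p is not None
]

print(f"mean with epsilon:      {sum(with_eps) / len(with_eps):.3f}")
print(f"mean over defined only: {sum(defined_only) / len(defined_only):.3f}")
# mean with epsilon:      0.567
# mean over defined only: 0.850
```

Which convention is appropriate depends on how empty predictions should count in the evaluation; the example in this page uses the epsilon form.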

Finally, the distribution of precision values over the test set is plotted to summarize the controlled performance.

fig, ax = plt.subplots(figsize=(10, 6))

ax.hist(precisions_array, bins=30, alpha=0.7, color="steelblue", edgecolor="black")
ax.axvline(
    TARGET_PRECISION,
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Target Precision ({TARGET_PRECISION})",
)
ax.axvline(
    precisions_array.mean(),
    color="orange",
    linestyle="--",
    linewidth=2,
    label=f"Mean Precision ({precisions_array.mean():.3f})",
)
ax.set_xlabel("Precision", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title("Distribution of Precision on Test Set", fontsize=14, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

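[Figure: Distribution of Precision on Test Set]

The histogram shows that most test images achieve or exceed the target precision level, illustrating the effectiveness of MAPIE’s risk control for semantic segmentation tasks. Individual images can still fall below the target; the mean precision, examined next, is the quantity compared against it.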

To assess how stable the mean precision is, we bootstrap it by resampling test images with replacement, then compare the (1 - confidence level) quantile of the bootstrap means with the target precision.

N_BOOTSTRAP = 2000
BOOTSTRAP_SEED = 123
rng = np.random.default_rng(BOOTSTRAP_SEED)

bootstrap_means = np.empty(N_BOOTSTRAP, dtype=float)
n = precisions_array.size
for b in range(N_BOOTSTRAP):
    bootstrap_sample = rng.choice(precisions_array, size=n, replace=True)
    bootstrap_means[b] = bootstrap_sample.mean()

delta = round(1 - CONFIDENCE_LEVEL, 2)
quantile_confidence = np.quantile(bootstrap_means, delta)
print(f"Bootstrap {delta}-th quantile: {quantile_confidence:.4f}")

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(
    bootstrap_means,
    bins=40,
    alpha=0.7,
    color="slateblue",
    edgecolor="black",
)
ax.axvline(
    TARGET_PRECISION,
    color="orange",
    linestyle="--",
    linewidth=2,
    label=f"Target precision ({TARGET_PRECISION:.2f})",
)
ax.axvline(
    quantile_confidence,
    color="green",
    linestyle="--",
    linewidth=2,
    label=f"Bootstrap {delta}-th quantile ({quantile_confidence:.3f})",
)
ax.set_xlabel("Bootstrap mean precision", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.set_title(
    "Bootstrap distribution of mean precision (test set resampling)",
    fontsize=14,
    fontweight="bold",
)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

[Figure: Bootstrap distribution of mean precision (test set resampling)]

Out:

Bootstrap 0.1-th quantile: 0.7745

Total running time of the script: ( 1 minutes 23.477 seconds)
