"""Video segmentation and tracking with multiple model architectures.
This module provides a unified interface for loading and running inference with
various video segmentation and tracking models including SAMURAI, SAM2Long,
SAM2.1, and YOLO11n-seg. Models support temporal consistency across frames,
occlusion handling, and mask-based segmentation output.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import numpy as np
import torch
from PIL import Image
logger = logging.getLogger(__name__)
# Detection thresholds
OCCLUSION_CONFIDENCE_THRESHOLD = 0.5
IOU_MATCH_THRESHOLD = 0.3
LOW_CONFIDENCE_IOU_THRESHOLD = 0.5
[docs]
class TrackingFramework(str, Enum):
"""Supported tracking frameworks for model execution."""
PYTORCH = "pytorch"
ULTRALYTICS = "ultralytics"
SAM2 = "sam2"
[docs]
@dataclass
class TrackingConfig:
"""Configuration for video tracking model loading and inference.
Parameters
----------
model_id : str
HuggingFace model identifier or model name.
framework : TrackingFramework
Framework to use for model execution.
device : str, default="cuda"
Device to load the model on.
cache_dir : Path | None, default=None
Directory for caching model weights.
checkpoint_path : Path | None, default=None
Path to model checkpoint file if using local weights.
"""
model_id: str
framework: TrackingFramework = TrackingFramework.PYTORCH
device: str = "cuda"
cache_dir: Path | None = None
checkpoint_path: Path | None = None
[docs]
@dataclass
class TrackingMask:
"""Segmentation mask for a tracked object.
Parameters
----------
mask : np.ndarray
Binary segmentation mask with shape (H, W) where values are 0 or 1.
confidence : float
Mask prediction confidence score (0.0 to 1.0).
object_id : int
Unique identifier for the tracked object across frames.
"""
mask: np.ndarray[Any, np.dtype[np.uint8]]
confidence: float
object_id: int
[docs]
def to_rle(self) -> dict[str, Any]:
"""Convert mask to Run-Length Encoding format.
Returns
-------
dict[str, Any]
RLE-encoded mask with 'size' and 'counts' keys.
"""
from pycocotools import mask as mask_utils
rle = mask_utils.encode(np.asfortranarray(self.mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8") # type: ignore[index]
return rle # type: ignore[no-any-return]
[docs]
@dataclass
class TrackingFrame:
"""Tracking results for a single video frame.
Parameters
----------
frame_idx : int
Zero-indexed frame number in the video sequence.
masks : list[TrackingMask]
List of segmentation masks for tracked objects in this frame.
occlusions : dict[int, bool]
Mapping of object_id to occlusion status (True if occluded).
processing_time : float
Processing time for this frame in seconds.
"""
frame_idx: int
masks: list[TrackingMask]
occlusions: dict[int, bool]
processing_time: float
[docs]
@dataclass
class TrackingResult:
"""Tracking results for a video sequence.
Parameters
----------
frames : list[TrackingFrame]
Tracking results for each frame in the sequence.
video_width : int
Video frame width in pixels.
video_height : int
Video frame height in pixels.
total_processing_time : float
Total processing time for all frames in seconds.
fps : float
Processing speed in frames per second.
"""
frames: list[TrackingFrame]
video_width: int
video_height: int
total_processing_time: float
fps: float
[docs]
class TrackingModelLoader(ABC):
"""Abstract base class for video tracking model loaders.
All tracking loaders must implement the load and track methods.
"""
[docs]
def __init__(self, config: TrackingConfig) -> None:
"""Initialize the tracking model loader with configuration.
Parameters
----------
config : TrackingConfig
Configuration for model loading and inference.
"""
self.config = config
self.model: Any = None
[docs]
@abstractmethod
def load(self) -> None:
"""Load the tracking model into memory with configured settings.
Raises
------
RuntimeError
If model loading fails.
"""
[docs]
@abstractmethod
def track(
self,
frames: list[Image.Image],
initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]],
object_ids: list[int],
) -> TrackingResult:
"""Track objects across video frames with mask-based segmentation.
Parameters
----------
frames : list[Image.Image]
List of PIL Images representing consecutive video frames.
initial_masks : list[np.ndarray]
Initial segmentation masks for objects in the first frame.
Each mask is a binary numpy array with shape (H, W).
object_ids : list[int]
Unique identifiers for each object to track.
Returns
-------
TrackingResult
Tracking results with segmentation masks for each frame.
Raises
------
RuntimeError
If tracking fails or model is not loaded.
ValueError
If number of initial_masks does not match object_ids length.
"""
[docs]
def unload(self) -> None:
"""Unload the model from memory to free GPU resources."""
if self.model is not None:
del self.model
self.model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Tracking model unloaded and memory cleared")
[docs]
class SAMURAILoader(TrackingModelLoader):
"""Loader for SAMURAI motion-aware tracking model.
SAMURAI achieves 7.1% better performance than SAM2 baseline with
motion-aware tracking and occlusion handling capabilities.
"""
[docs]
def load(self) -> None:
"""Load SAMURAI model with configured settings."""
try:
logger.info(f"Loading SAMURAI from {self.config.model_id}")
# SAMURAI is built on SAM2 architecture with motion awareness
from sam2.build_sam import build_sam2_video_predictor
checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None
config_file = "sam2_hiera_l.yaml"
self.predictor = build_sam2_video_predictor(
config_file=config_file,
ckpt_path=checkpoint,
device=self.config.device,
)
self.model = self.predictor.model
logger.info("SAMURAI loaded successfully")
except Exception as e:
logger.error(f"Failed to load SAMURAI: {e}")
raise RuntimeError(f"Model loading failed: {e}") from e
[docs]
def track(
self,
frames: list[Image.Image],
initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]],
object_ids: list[int],
) -> TrackingResult:
"""Track objects using SAMURAI with motion-aware tracking."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load() first.")
if len(initial_masks) != len(object_ids):
raise ValueError(
f"Number of initial_masks ({len(initial_masks)}) must match "
f"object_ids length ({len(object_ids)})"
)
import time
try:
start_time = time.time()
height, width = frames[0].size[1], frames[0].size[0]
tracking_frames: list[TrackingFrame] = []
# Initialize inference state
inference_state = self.predictor.init_state(
video=np.array([np.array(f) for f in frames])
)
# Add initial masks for tracking
for obj_id, mask in zip(object_ids, initial_masks, strict=False):
self.predictor.add_new_mask(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
mask=mask,
)
# Propagate masks across frames
for frame_idx in range(len(frames)):
frame_start = time.time()
video_segments = self.predictor.propagate_in_video(
inference_state, start_frame_idx=frame_idx, max_frame_num_to_track=1
)
masks = []
occlusions = {}
for obj_id, obj_masks in video_segments.items():
if frame_idx in obj_masks:
mask_data = obj_masks[frame_idx]
mask = mask_data[0] > 0 # Binary mask
# Convert tensor to numpy if needed
if isinstance(mask, torch.Tensor):
mask = mask.cpu().numpy()
# Detect occlusion based on mask quality
confidence_tensor = (
mask_data[0].max()
if isinstance(mask_data[0], torch.Tensor)
else mask_data[0]
)
if isinstance(confidence_tensor, torch.Tensor):
confidence = float(confidence_tensor.cpu().numpy())
else:
confidence = float(confidence_tensor)
is_occluded = confidence < OCCLUSION_CONFIDENCE_THRESHOLD
masks.append(
TrackingMask(
mask=mask.astype(np.uint8),
confidence=confidence,
object_id=obj_id,
)
)
occlusions[obj_id] = is_occluded
frame_time = time.time() - frame_start
tracking_frames.append(
TrackingFrame(
frame_idx=frame_idx,
masks=masks,
occlusions=occlusions,
processing_time=frame_time,
)
)
total_time = time.time() - start_time
fps = len(frames) / total_time if total_time > 0 else 0.0
return TrackingResult(
frames=tracking_frames,
video_width=width,
video_height=height,
total_processing_time=total_time,
fps=fps,
)
except Exception as e:
logger.error(f"Tracking failed: {e}")
raise RuntimeError(f"Video tracking failed: {e}") from e
[docs]
class SAM2LongLoader(TrackingModelLoader):
"""Loader for SAM2Long long video tracking model.
SAM2Long achieves 5.3% better performance than SAM2 baseline with
error accumulation fixes for long video sequences.
"""
[docs]
def load(self) -> None:
"""Load SAM2Long model with configured settings."""
try:
logger.info(f"Loading SAM2Long from {self.config.model_id}")
from sam2.build_sam import build_sam2_video_predictor
checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None
config_file = "sam2_hiera_l.yaml"
self.predictor = build_sam2_video_predictor(
config_file=config_file,
ckpt_path=checkpoint,
device=self.config.device,
)
self.model = self.predictor.model
logger.info("SAM2Long loaded successfully")
except Exception as e:
logger.error(f"Failed to load SAM2Long: {e}")
raise RuntimeError(f"Model loading failed: {e}") from e
[docs]
def track(
self,
frames: list[Image.Image],
initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]],
object_ids: list[int],
) -> TrackingResult:
"""Track objects using SAM2Long with error accumulation fixes."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load() first.")
if len(initial_masks) != len(object_ids):
raise ValueError(
f"Number of initial_masks ({len(initial_masks)}) must match "
f"object_ids length ({len(object_ids)})"
)
import time
try:
start_time = time.time()
height, width = frames[0].size[1], frames[0].size[0]
tracking_frames: list[TrackingFrame] = []
# SAM2Long uses memory-efficient propagation for long videos
inference_state = self.predictor.init_state(
video=np.array([np.array(f) for f in frames])
)
for obj_id, mask in zip(object_ids, initial_masks, strict=False):
self.predictor.add_new_mask(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
mask=mask,
)
# Process in chunks to avoid error accumulation
chunk_size = 30 # Process 30 frames at a time
for chunk_start in range(0, len(frames), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(frames))
video_segments = self.predictor.propagate_in_video(
inference_state,
start_frame_idx=chunk_start,
max_frame_num_to_track=chunk_end - chunk_start,
)
for frame_idx in range(chunk_start, chunk_end):
frame_start = time.time()
masks = []
occlusions = {}
for obj_id, obj_masks in video_segments.items():
if frame_idx in obj_masks:
mask_data = obj_masks[frame_idx]
mask = mask_data[0] > 0
# Convert tensor to numpy if needed
if isinstance(mask, torch.Tensor):
mask = mask.cpu().numpy()
confidence_tensor = (
mask_data[0].max()
if isinstance(mask_data[0], torch.Tensor)
else mask_data[0]
)
if isinstance(confidence_tensor, torch.Tensor):
confidence = float(confidence_tensor.cpu().numpy())
else:
confidence = float(confidence_tensor)
is_occluded = confidence < OCCLUSION_CONFIDENCE_THRESHOLD
masks.append(
TrackingMask(
mask=mask.astype(np.uint8),
confidence=confidence,
object_id=obj_id,
)
)
occlusions[obj_id] = is_occluded
frame_time = time.time() - frame_start
tracking_frames.append(
TrackingFrame(
frame_idx=frame_idx,
masks=masks,
occlusions=occlusions,
processing_time=frame_time,
)
)
total_time = time.time() - start_time
fps = len(frames) / total_time if total_time > 0 else 0.0
return TrackingResult(
frames=tracking_frames,
video_width=width,
video_height=height,
total_processing_time=total_time,
fps=fps,
)
except Exception as e:
logger.error(f"Tracking failed: {e}")
raise RuntimeError(f"Video tracking failed: {e}") from e
[docs]
class SAM2Loader(TrackingModelLoader):
"""Loader for SAM2.1 baseline video segmentation model.
SAM2.1 provides baseline performance with proven stability for
general video segmentation and tracking tasks.
"""
[docs]
def load(self) -> None:
"""Load SAM2.1 model with configured settings."""
try:
logger.info(f"Loading SAM2.1 from {self.config.model_id}")
from sam2.build_sam import build_sam2_video_predictor
checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None
config_file = "sam2_hiera_l.yaml"
self.predictor = build_sam2_video_predictor(
config_file=config_file,
ckpt_path=checkpoint,
device=self.config.device,
)
self.model = self.predictor.model
logger.info("SAM2.1 loaded successfully")
except Exception as e:
logger.error(f"Failed to load SAM2.1: {e}")
raise RuntimeError(f"Model loading failed: {e}") from e
[docs]
def track(
self,
frames: list[Image.Image],
initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]],
object_ids: list[int],
) -> TrackingResult:
"""Track objects using SAM2.1 baseline implementation."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load() first.")
if len(initial_masks) != len(object_ids):
raise ValueError(
f"Number of initial_masks ({len(initial_masks)}) must match "
f"object_ids length ({len(object_ids)})"
)
import time
try:
start_time = time.time()
height, width = frames[0].size[1], frames[0].size[0]
tracking_frames: list[TrackingFrame] = []
inference_state = self.predictor.init_state(
video=np.array([np.array(f) for f in frames])
)
for obj_id, mask in zip(object_ids, initial_masks, strict=False):
self.predictor.add_new_mask(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
mask=mask,
)
video_segments = self.predictor.propagate_in_video(inference_state)
for frame_idx in range(len(frames)):
frame_start = time.time()
masks = []
occlusions = {}
for obj_id, obj_masks in video_segments.items():
if frame_idx in obj_masks:
mask_data = obj_masks[frame_idx]
mask = mask_data[0] > 0
# Convert tensor to numpy if needed
if isinstance(mask, torch.Tensor):
mask = mask.cpu().numpy()
confidence_tensor = (
mask_data[0].max()
if isinstance(mask_data[0], torch.Tensor)
else mask_data[0]
)
if isinstance(confidence_tensor, torch.Tensor):
confidence = float(confidence_tensor.cpu().numpy())
else:
confidence = float(confidence_tensor)
is_occluded = False # SAM2.1 baseline does not detect occlusion
masks.append(
TrackingMask(
mask=mask.astype(np.uint8),
confidence=confidence,
object_id=obj_id,
)
)
occlusions[obj_id] = is_occluded
frame_time = time.time() - frame_start
tracking_frames.append(
TrackingFrame(
frame_idx=frame_idx,
masks=masks,
occlusions=occlusions,
processing_time=frame_time,
)
)
total_time = time.time() - start_time
fps = len(frames) / total_time if total_time > 0 else 0.0
return TrackingResult(
frames=tracking_frames,
video_width=width,
video_height=height,
total_processing_time=total_time,
fps=fps,
)
except Exception as e:
logger.error(f"Tracking failed: {e}")
raise RuntimeError(f"Video tracking failed: {e}") from e
[docs]
class YOLO11SegLoader(TrackingModelLoader):
"""Loader for YOLO11n-seg lightweight segmentation model.
YOLO11n-seg is a 2.7M parameter model optimized for real-time
segmentation in speed-critical applications.
"""
[docs]
def load(self) -> None:
"""Load YOLO11n-seg model with configured settings."""
try:
from ultralytics import YOLO # type: ignore[attr-defined]
logger.info(f"Loading YOLO11n-seg from {self.config.model_id}")
self.model = YOLO(self.config.model_id)
if torch.cuda.is_available():
self.model.to(self.config.device)
logger.info("YOLO11n-seg loaded successfully")
except Exception as e:
logger.error(f"Failed to load YOLO11n-seg: {e}")
raise RuntimeError(f"Model loading failed: {e}") from e
[docs]
def track( # noqa: PLR0912
self,
frames: list[Image.Image],
initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]],
object_ids: list[int],
) -> TrackingResult:
"""Track objects using YOLO11n-seg with per-frame segmentation.
Note: YOLO11n-seg performs independent segmentation per frame without
temporal consistency. Object re-identification is based on spatial overlap.
"""
if self.model is None:
raise RuntimeError("Model not loaded. Call load() first.")
if len(initial_masks) != len(object_ids):
raise ValueError(
f"Number of initial_masks ({len(initial_masks)}) must match "
f"object_ids length ({len(object_ids)})"
)
import time
try:
start_time = time.time()
height, width = frames[0].size[1], frames[0].size[0]
tracking_frames: list[TrackingFrame] = []
# Track objects based on spatial overlap
prev_masks = dict(zip(object_ids, initial_masks, strict=False))
for frame_idx, frame in enumerate(frames):
frame_start = time.time()
# Run segmentation
results = self.model(np.array(frame), verbose=False)[0]
masks = []
occlusions = {}
if results.masks is not None:
# Match detected masks to tracked objects by IoU
detected_masks = results.masks.data.cpu().numpy()
for obj_id in object_ids:
if obj_id in prev_masks:
# Find best matching mask by IoU
best_iou = 0.0
best_mask = None
for det_mask in detected_masks:
# Resize detected mask to match frame dimensions if needed
if det_mask.shape != (height, width):
import cv2
det_mask_resized = cv2.resize(
det_mask.astype(np.uint8),
(width, height),
interpolation=cv2.INTER_NEAREST,
).astype(np.float32)
else:
det_mask_resized = det_mask
iou = self._compute_iou(prev_masks[obj_id], det_mask_resized)
if iou > best_iou:
best_iou = iou
best_mask = det_mask_resized
if best_mask is not None and best_iou > IOU_MATCH_THRESHOLD:
masks.append(
TrackingMask(
mask=best_mask.astype(np.uint8),
confidence=best_iou,
object_id=obj_id,
)
)
prev_masks[obj_id] = best_mask
occlusions[obj_id] = best_iou < LOW_CONFIDENCE_IOU_THRESHOLD
else:
# Object lost (occluded or out of frame)
occlusions[obj_id] = True
frame_time = time.time() - frame_start
tracking_frames.append(
TrackingFrame(
frame_idx=frame_idx,
masks=masks,
occlusions=occlusions,
processing_time=frame_time,
)
)
total_time = time.time() - start_time
fps = len(frames) / total_time if total_time > 0 else 0.0
return TrackingResult(
frames=tracking_frames,
video_width=width,
video_height=height,
total_processing_time=total_time,
fps=fps,
)
except Exception as e:
logger.error(f"Tracking failed: {e}")
raise RuntimeError(f"Video tracking failed: {e}") from e
def _compute_iou(
self,
mask1: np.ndarray[Any, np.dtype[Any]],
mask2: np.ndarray[Any, np.dtype[Any]],
) -> float:
"""Compute Intersection over Union between two masks.
Parameters
----------
mask1 : np.ndarray
First binary mask.
mask2 : np.ndarray
Second binary mask.
Returns
-------
float
IoU score between 0.0 and 1.0.
"""
intersection = np.logical_and(mask1 > 0, mask2 > 0).sum()
union = np.logical_or(mask1 > 0, mask2 > 0).sum()
if union == 0:
return 0.0
return float(intersection / union)
[docs]
def create_tracking_loader(model_name: str, config: TrackingConfig) -> TrackingModelLoader:
"""Factory function to create appropriate tracking loader based on model name.
Parameters
----------
model_name : str
Name of the model to load. Supported values:
- "samurai" (default)
- "sam2long" or "sam2-long"
- "sam2" or "sam2.1"
- "yolo11n-seg" or "yolo11seg"
config : TrackingConfig
Configuration for model loading and inference.
Returns
-------
TrackingModelLoader
Appropriate loader instance for the specified model.
Raises
------
ValueError
If model_name is not recognized.
"""
model_name_lower = model_name.lower().replace("_", "-")
if "samurai" in model_name_lower:
return SAMURAILoader(config)
if "sam2long" in model_name_lower or "sam2-long" in model_name_lower:
return SAM2LongLoader(config)
if "sam2" in model_name_lower:
return SAM2Loader(config)
if "yolo11" in model_name_lower and "seg" in model_name_lower:
return YOLO11SegLoader(config)
raise ValueError(
f"Unknown model name: {model_name}. Supported models: samurai, sam2long, sam2, yolo11n-seg"
)