Source code for tracking_loader

"""Video segmentation and tracking with multiple model architectures.

This module provides a unified interface for loading and running inference with
various video segmentation and tracking models including SAMURAI, SAM2Long,
SAM2.1, and YOLO11n-seg. Models support temporal consistency across frames,
occlusion handling, and mask-based segmentation output.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image

logger = logging.getLogger(__name__)

# Detection thresholds
OCCLUSION_CONFIDENCE_THRESHOLD = 0.5
IOU_MATCH_THRESHOLD = 0.3
LOW_CONFIDENCE_IOU_THRESHOLD = 0.5


[docs] class TrackingFramework(str, Enum): """Supported tracking frameworks for model execution.""" PYTORCH = "pytorch" ULTRALYTICS = "ultralytics" SAM2 = "sam2"
[docs] @dataclass class TrackingConfig: """Configuration for video tracking model loading and inference. Parameters ---------- model_id : str HuggingFace model identifier or model name. framework : TrackingFramework Framework to use for model execution. device : str, default="cuda" Device to load the model on. cache_dir : Path | None, default=None Directory for caching model weights. checkpoint_path : Path | None, default=None Path to model checkpoint file if using local weights. """ model_id: str framework: TrackingFramework = TrackingFramework.PYTORCH device: str = "cuda" cache_dir: Path | None = None checkpoint_path: Path | None = None
[docs] @dataclass class TrackingMask: """Segmentation mask for a tracked object. Parameters ---------- mask : np.ndarray Binary segmentation mask with shape (H, W) where values are 0 or 1. confidence : float Mask prediction confidence score (0.0 to 1.0). object_id : int Unique identifier for the tracked object across frames. """ mask: np.ndarray[Any, np.dtype[np.uint8]] confidence: float object_id: int
[docs] def to_rle(self) -> dict[str, Any]: """Convert mask to Run-Length Encoding format. Returns ------- dict[str, Any] RLE-encoded mask with 'size' and 'counts' keys. """ from pycocotools import mask as mask_utils rle = mask_utils.encode(np.asfortranarray(self.mask.astype(np.uint8))) rle["counts"] = rle["counts"].decode("utf-8") # type: ignore[index] return rle # type: ignore[no-any-return]
[docs] @dataclass class TrackingFrame: """Tracking results for a single video frame. Parameters ---------- frame_idx : int Zero-indexed frame number in the video sequence. masks : list[TrackingMask] List of segmentation masks for tracked objects in this frame. occlusions : dict[int, bool] Mapping of object_id to occlusion status (True if occluded). processing_time : float Processing time for this frame in seconds. """ frame_idx: int masks: list[TrackingMask] occlusions: dict[int, bool] processing_time: float
[docs] @dataclass class TrackingResult: """Tracking results for a video sequence. Parameters ---------- frames : list[TrackingFrame] Tracking results for each frame in the sequence. video_width : int Video frame width in pixels. video_height : int Video frame height in pixels. total_processing_time : float Total processing time for all frames in seconds. fps : float Processing speed in frames per second. """ frames: list[TrackingFrame] video_width: int video_height: int total_processing_time: float fps: float
[docs] class TrackingModelLoader(ABC): """Abstract base class for video tracking model loaders. All tracking loaders must implement the load and track methods. """
[docs] def __init__(self, config: TrackingConfig) -> None: """Initialize the tracking model loader with configuration. Parameters ---------- config : TrackingConfig Configuration for model loading and inference. """ self.config = config self.model: Any = None
[docs] @abstractmethod def load(self) -> None: """Load the tracking model into memory with configured settings. Raises ------ RuntimeError If model loading fails. """
[docs] @abstractmethod def track( self, frames: list[Image.Image], initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]], object_ids: list[int], ) -> TrackingResult: """Track objects across video frames with mask-based segmentation. Parameters ---------- frames : list[Image.Image] List of PIL Images representing consecutive video frames. initial_masks : list[np.ndarray] Initial segmentation masks for objects in the first frame. Each mask is a binary numpy array with shape (H, W). object_ids : list[int] Unique identifiers for each object to track. Returns ------- TrackingResult Tracking results with segmentation masks for each frame. Raises ------ RuntimeError If tracking fails or model is not loaded. ValueError If number of initial_masks does not match object_ids length. """
[docs] def unload(self) -> None: """Unload the model from memory to free GPU resources.""" if self.model is not None: del self.model self.model = None if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Tracking model unloaded and memory cleared")
[docs] class SAMURAILoader(TrackingModelLoader): """Loader for SAMURAI motion-aware tracking model. SAMURAI achieves 7.1% better performance than SAM2 baseline with motion-aware tracking and occlusion handling capabilities. """
[docs] def load(self) -> None: """Load SAMURAI model with configured settings.""" try: logger.info(f"Loading SAMURAI from {self.config.model_id}") # SAMURAI is built on SAM2 architecture with motion awareness from sam2.build_sam import build_sam2_video_predictor checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None config_file = "sam2_hiera_l.yaml" self.predictor = build_sam2_video_predictor( config_file=config_file, ckpt_path=checkpoint, device=self.config.device, ) self.model = self.predictor.model logger.info("SAMURAI loaded successfully") except Exception as e: logger.error(f"Failed to load SAMURAI: {e}") raise RuntimeError(f"Model loading failed: {e}") from e
[docs] def track( self, frames: list[Image.Image], initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]], object_ids: list[int], ) -> TrackingResult: """Track objects using SAMURAI with motion-aware tracking.""" if self.model is None: raise RuntimeError("Model not loaded. Call load() first.") if len(initial_masks) != len(object_ids): raise ValueError( f"Number of initial_masks ({len(initial_masks)}) must match " f"object_ids length ({len(object_ids)})" ) import time try: start_time = time.time() height, width = frames[0].size[1], frames[0].size[0] tracking_frames: list[TrackingFrame] = [] # Initialize inference state inference_state = self.predictor.init_state( video=np.array([np.array(f) for f in frames]) ) # Add initial masks for tracking for obj_id, mask in zip(object_ids, initial_masks, strict=False): self.predictor.add_new_mask( inference_state=inference_state, frame_idx=0, obj_id=obj_id, mask=mask, ) # Propagate masks across frames for frame_idx in range(len(frames)): frame_start = time.time() video_segments = self.predictor.propagate_in_video( inference_state, start_frame_idx=frame_idx, max_frame_num_to_track=1 ) masks = [] occlusions = {} for obj_id, obj_masks in video_segments.items(): if frame_idx in obj_masks: mask_data = obj_masks[frame_idx] mask = mask_data[0] > 0 # Binary mask # Convert tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() # Detect occlusion based on mask quality confidence_tensor = ( mask_data[0].max() if isinstance(mask_data[0], torch.Tensor) else mask_data[0] ) if isinstance(confidence_tensor, torch.Tensor): confidence = float(confidence_tensor.cpu().numpy()) else: confidence = float(confidence_tensor) is_occluded = confidence < OCCLUSION_CONFIDENCE_THRESHOLD masks.append( TrackingMask( mask=mask.astype(np.uint8), confidence=confidence, object_id=obj_id, ) ) occlusions[obj_id] = is_occluded frame_time = time.time() - frame_start tracking_frames.append( TrackingFrame( frame_idx=frame_idx, masks=masks, occlusions=occlusions, processing_time=frame_time, ) ) total_time = time.time() - start_time fps = len(frames) / total_time if total_time > 0 else 0.0 return TrackingResult( frames=tracking_frames, video_width=width, video_height=height, total_processing_time=total_time, fps=fps, ) except Exception as e: logger.error(f"Tracking failed: {e}") raise RuntimeError(f"Video tracking failed: {e}") from e
[docs] class SAM2LongLoader(TrackingModelLoader): """Loader for SAM2Long long video tracking model. SAM2Long achieves 5.3% better performance than SAM2 baseline with error accumulation fixes for long video sequences. """
[docs] def load(self) -> None: """Load SAM2Long model with configured settings.""" try: logger.info(f"Loading SAM2Long from {self.config.model_id}") from sam2.build_sam import build_sam2_video_predictor checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None config_file = "sam2_hiera_l.yaml" self.predictor = build_sam2_video_predictor( config_file=config_file, ckpt_path=checkpoint, device=self.config.device, ) self.model = self.predictor.model logger.info("SAM2Long loaded successfully") except Exception as e: logger.error(f"Failed to load SAM2Long: {e}") raise RuntimeError(f"Model loading failed: {e}") from e
[docs] def track( self, frames: list[Image.Image], initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]], object_ids: list[int], ) -> TrackingResult: """Track objects using SAM2Long with error accumulation fixes.""" if self.model is None: raise RuntimeError("Model not loaded. Call load() first.") if len(initial_masks) != len(object_ids): raise ValueError( f"Number of initial_masks ({len(initial_masks)}) must match " f"object_ids length ({len(object_ids)})" ) import time try: start_time = time.time() height, width = frames[0].size[1], frames[0].size[0] tracking_frames: list[TrackingFrame] = [] # SAM2Long uses memory-efficient propagation for long videos inference_state = self.predictor.init_state( video=np.array([np.array(f) for f in frames]) ) for obj_id, mask in zip(object_ids, initial_masks, strict=False): self.predictor.add_new_mask( inference_state=inference_state, frame_idx=0, obj_id=obj_id, mask=mask, ) # Process in chunks to avoid error accumulation chunk_size = 30 # Process 30 frames at a time for chunk_start in range(0, len(frames), chunk_size): chunk_end = min(chunk_start + chunk_size, len(frames)) video_segments = self.predictor.propagate_in_video( inference_state, start_frame_idx=chunk_start, max_frame_num_to_track=chunk_end - chunk_start, ) for frame_idx in range(chunk_start, chunk_end): frame_start = time.time() masks = [] occlusions = {} for obj_id, obj_masks in video_segments.items(): if frame_idx in obj_masks: mask_data = obj_masks[frame_idx] mask = mask_data[0] > 0 # Convert tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() confidence_tensor = ( mask_data[0].max() if isinstance(mask_data[0], torch.Tensor) else mask_data[0] ) if isinstance(confidence_tensor, torch.Tensor): confidence = float(confidence_tensor.cpu().numpy()) else: confidence = float(confidence_tensor) is_occluded = confidence < OCCLUSION_CONFIDENCE_THRESHOLD masks.append( TrackingMask( mask=mask.astype(np.uint8), confidence=confidence, object_id=obj_id, ) ) occlusions[obj_id] = is_occluded frame_time = time.time() - frame_start tracking_frames.append( TrackingFrame( frame_idx=frame_idx, masks=masks, occlusions=occlusions, processing_time=frame_time, ) ) total_time = time.time() - start_time fps = len(frames) / total_time if total_time > 0 else 0.0 return TrackingResult( frames=tracking_frames, video_width=width, video_height=height, total_processing_time=total_time, fps=fps, ) except Exception as e: logger.error(f"Tracking failed: {e}") raise RuntimeError(f"Video tracking failed: {e}") from e
[docs] class SAM2Loader(TrackingModelLoader): """Loader for SAM2.1 baseline video segmentation model. SAM2.1 provides baseline performance with proven stability for general video segmentation and tracking tasks. """
[docs] def load(self) -> None: """Load SAM2.1 model with configured settings.""" try: logger.info(f"Loading SAM2.1 from {self.config.model_id}") from sam2.build_sam import build_sam2_video_predictor checkpoint = str(self.config.checkpoint_path) if self.config.checkpoint_path else None config_file = "sam2_hiera_l.yaml" self.predictor = build_sam2_video_predictor( config_file=config_file, ckpt_path=checkpoint, device=self.config.device, ) self.model = self.predictor.model logger.info("SAM2.1 loaded successfully") except Exception as e: logger.error(f"Failed to load SAM2.1: {e}") raise RuntimeError(f"Model loading failed: {e}") from e
[docs] def track( self, frames: list[Image.Image], initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]], object_ids: list[int], ) -> TrackingResult: """Track objects using SAM2.1 baseline implementation.""" if self.model is None: raise RuntimeError("Model not loaded. Call load() first.") if len(initial_masks) != len(object_ids): raise ValueError( f"Number of initial_masks ({len(initial_masks)}) must match " f"object_ids length ({len(object_ids)})" ) import time try: start_time = time.time() height, width = frames[0].size[1], frames[0].size[0] tracking_frames: list[TrackingFrame] = [] inference_state = self.predictor.init_state( video=np.array([np.array(f) for f in frames]) ) for obj_id, mask in zip(object_ids, initial_masks, strict=False): self.predictor.add_new_mask( inference_state=inference_state, frame_idx=0, obj_id=obj_id, mask=mask, ) video_segments = self.predictor.propagate_in_video(inference_state) for frame_idx in range(len(frames)): frame_start = time.time() masks = [] occlusions = {} for obj_id, obj_masks in video_segments.items(): if frame_idx in obj_masks: mask_data = obj_masks[frame_idx] mask = mask_data[0] > 0 # Convert tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() confidence_tensor = ( mask_data[0].max() if isinstance(mask_data[0], torch.Tensor) else mask_data[0] ) if isinstance(confidence_tensor, torch.Tensor): confidence = float(confidence_tensor.cpu().numpy()) else: confidence = float(confidence_tensor) is_occluded = False # SAM2.1 baseline does not detect occlusion masks.append( TrackingMask( mask=mask.astype(np.uint8), confidence=confidence, object_id=obj_id, ) ) occlusions[obj_id] = is_occluded frame_time = time.time() - frame_start tracking_frames.append( TrackingFrame( frame_idx=frame_idx, masks=masks, occlusions=occlusions, processing_time=frame_time, ) ) total_time = time.time() - start_time fps = len(frames) / total_time if total_time > 0 else 0.0 return TrackingResult( frames=tracking_frames, video_width=width, video_height=height, total_processing_time=total_time, fps=fps, ) except Exception as e: logger.error(f"Tracking failed: {e}") raise RuntimeError(f"Video tracking failed: {e}") from e
[docs] class YOLO11SegLoader(TrackingModelLoader): """Loader for YOLO11n-seg lightweight segmentation model. YOLO11n-seg is a 2.7M parameter model optimized for real-time segmentation in speed-critical applications. """
[docs] def load(self) -> None: """Load YOLO11n-seg model with configured settings.""" try: from ultralytics import YOLO # type: ignore[attr-defined] logger.info(f"Loading YOLO11n-seg from {self.config.model_id}") self.model = YOLO(self.config.model_id) if torch.cuda.is_available(): self.model.to(self.config.device) logger.info("YOLO11n-seg loaded successfully") except Exception as e: logger.error(f"Failed to load YOLO11n-seg: {e}") raise RuntimeError(f"Model loading failed: {e}") from e
[docs] def track( # noqa: PLR0912 self, frames: list[Image.Image], initial_masks: list[np.ndarray[Any, np.dtype[np.uint8]]], object_ids: list[int], ) -> TrackingResult: """Track objects using YOLO11n-seg with per-frame segmentation. Note: YOLO11n-seg performs independent segmentation per frame without temporal consistency. Object re-identification is based on spatial overlap. """ if self.model is None: raise RuntimeError("Model not loaded. Call load() first.") if len(initial_masks) != len(object_ids): raise ValueError( f"Number of initial_masks ({len(initial_masks)}) must match " f"object_ids length ({len(object_ids)})" ) import time try: start_time = time.time() height, width = frames[0].size[1], frames[0].size[0] tracking_frames: list[TrackingFrame] = [] # Track objects based on spatial overlap prev_masks = dict(zip(object_ids, initial_masks, strict=False)) for frame_idx, frame in enumerate(frames): frame_start = time.time() # Run segmentation results = self.model(np.array(frame), verbose=False)[0] masks = [] occlusions = {} if results.masks is not None: # Match detected masks to tracked objects by IoU detected_masks = results.masks.data.cpu().numpy() for obj_id in object_ids: if obj_id in prev_masks: # Find best matching mask by IoU best_iou = 0.0 best_mask = None for det_mask in detected_masks: # Resize detected mask to match frame dimensions if needed if det_mask.shape != (height, width): import cv2 det_mask_resized = cv2.resize( det_mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST, ).astype(np.float32) else: det_mask_resized = det_mask iou = self._compute_iou(prev_masks[obj_id], det_mask_resized) if iou > best_iou: best_iou = iou best_mask = det_mask_resized if best_mask is not None and best_iou > IOU_MATCH_THRESHOLD: masks.append( TrackingMask( mask=best_mask.astype(np.uint8), confidence=best_iou, object_id=obj_id, ) ) prev_masks[obj_id] = best_mask occlusions[obj_id] = best_iou < LOW_CONFIDENCE_IOU_THRESHOLD else: # Object lost (occluded or out of frame) occlusions[obj_id] = True frame_time = time.time() - frame_start tracking_frames.append( TrackingFrame( frame_idx=frame_idx, masks=masks, occlusions=occlusions, processing_time=frame_time, ) ) total_time = time.time() - start_time fps = len(frames) / total_time if total_time > 0 else 0.0 return TrackingResult( frames=tracking_frames, video_width=width, video_height=height, total_processing_time=total_time, fps=fps, ) except Exception as e: logger.error(f"Tracking failed: {e}") raise RuntimeError(f"Video tracking failed: {e}") from e
def _compute_iou( self, mask1: np.ndarray[Any, np.dtype[Any]], mask2: np.ndarray[Any, np.dtype[Any]], ) -> float: """Compute Intersection over Union between two masks. Parameters ---------- mask1 : np.ndarray First binary mask. mask2 : np.ndarray Second binary mask. Returns ------- float IoU score between 0.0 and 1.0. """ intersection = np.logical_and(mask1 > 0, mask2 > 0).sum() union = np.logical_or(mask1 > 0, mask2 > 0).sum() if union == 0: return 0.0 return float(intersection / union)
[docs] def create_tracking_loader(model_name: str, config: TrackingConfig) -> TrackingModelLoader: """Factory function to create appropriate tracking loader based on model name. Parameters ---------- model_name : str Name of the model to load. Supported values: - "samurai" (default) - "sam2long" or "sam2-long" - "sam2" or "sam2.1" - "yolo11n-seg" or "yolo11seg" config : TrackingConfig Configuration for model loading and inference. Returns ------- TrackingModelLoader Appropriate loader instance for the specified model. Raises ------ ValueError If model_name is not recognized. """ model_name_lower = model_name.lower().replace("_", "-") if "samurai" in model_name_lower: return SAMURAILoader(config) if "sam2long" in model_name_lower or "sam2-long" in model_name_lower: return SAM2LongLoader(config) if "sam2" in model_name_lower: return SAM2Loader(config) if "yolo11" in model_name_lower and "seg" in model_name_lower: return YOLO11SegLoader(config) raise ValueError( f"Unknown model name: {model_name}. Supported models: samurai, sam2long, sam2, yolo11n-seg" )