Source code for detection_loader

"""Open-vocabulary object detection with multiple model architectures.

This module provides a unified interface for loading and running inference with
various open-vocabulary object detection models including YOLO-World v2.1,
Grounding DINO 1.5, OWLv2, and Florence-2. Models support text-based prompts
for detecting objects without pre-defined class vocabularies.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image

logger = logging.getLogger(__name__)


class DetectionFramework(str, Enum):
    """Execution backends a detection model can be run with.

    Every member is also a ``str`` instance, so it compares equal to its
    plain string value (for example ``DetectionFramework.PYTORCH == "pytorch"``).
    """

    PYTORCH = "pytorch"
    ULTRALYTICS = "ultralytics"
    TRANSFORMERS = "transformers"
@dataclass
class DetectionConfig:
    """Settings that control how a detection model is loaded and queried.

    Parameters
    ----------
    model_id : str
        HuggingFace model identifier or Ultralytics model name.
    framework : DetectionFramework
        Backend used to execute the model. Defaults to
        ``DetectionFramework.PYTORCH``.
    confidence_threshold : float, default=0.25
        Lowest confidence score (0.0 to 1.0) a detection may have to be kept.
    device : str, default="cuda"
        Name of the device the model should be placed on.
    cache_dir : Path | None, default=None
        Where downloaded model weights are cached; ``None`` leaves the choice
        to the underlying library.
    """

    # Only ``model_id`` is required; every other field has a default.
    model_id: str
    framework: DetectionFramework = DetectionFramework.PYTORCH
    confidence_threshold: float = 0.25
    device: str = "cuda"
    cache_dir: Path | None = None
@dataclass
class BoundingBox:
    """Axis-aligned box expressed as fractions of the image size.

    Parameters
    ----------
    x1 : float
        Left edge (0.0 to 1.0, as a fraction of the image width).
    y1 : float
        Top edge (0.0 to 1.0, as a fraction of the image height).
    x2 : float
        Right edge (0.0 to 1.0, as a fraction of the image width).
    y2 : float
        Bottom edge (0.0 to 1.0, as a fraction of the image height).
    """

    x1: float
    y1: float
    x2: float
    y2: float

    def to_absolute(self, width: int, height: int) -> tuple[int, int, int, int]:
        """Scale the normalized box to pixel coordinates.

        Parameters
        ----------
        width : int
            Image width in pixels.
        height : int
            Image height in pixels.

        Returns
        -------
        tuple[int, int, int, int]
            ``(x1, y1, x2, y2)`` in pixels. Each value is truncated towards
            zero by ``int()``, not rounded.
        """
        left = int(self.x1 * width)
        top = int(self.y1 * height)
        right = int(self.x2 * width)
        bottom = int(self.y2 * height)
        return left, top, right, bottom
@dataclass
class Detection:
    """One detected object.

    Parameters
    ----------
    bbox : BoundingBox
        Location of the object in normalized coordinates.
    confidence : float
        Score attached to the detection (0.0 to 1.0).
    label : str
        Class name or phrase the object was matched to.
    """

    bbox: BoundingBox
    confidence: float
    label: str
@dataclass
class DetectionResult:
    """Everything a loader reports for one processed image.

    Parameters
    ----------
    detections : list[Detection]
        Objects found in the image, each with a bounding box and a score.
    image_width : int
        Width of the input image in pixels.
    image_height : int
        Height of the input image in pixels.
    processing_time : float
        Time spent on the detection call, in seconds.
    """

    detections: list[Detection]
    image_width: int
    image_height: int
    processing_time: float
class DetectionModelLoader(ABC):
    """Common interface shared by every detection model loader.

    Concrete loaders must provide ``load`` and ``detect``; ``unload`` is
    implemented here.
    """

    def __init__(self, config: DetectionConfig) -> None:
        """Store the configuration; no model is loaded yet.

        Parameters
        ----------
        config : DetectionConfig
            Configuration for model loading and inference.
        """
        self.config = config
        # Stays None until a subclass's load() assigns the model object.
        self.model: Any = None

    @abstractmethod
    def load(self) -> None:
        """Bring the detection model into memory using ``self.config``.

        Raises
        ------
        RuntimeError
            If model loading fails.
        """

    @abstractmethod
    def detect(
        self,
        image: Image.Image,
        text_prompt: str,
    ) -> DetectionResult:
        """Find the objects described by a text prompt in an image.

        Parameters
        ----------
        image : Image.Image
            PIL Image to process.
        text_prompt : str
            Text description of objects to detect (e.g., "person. car. dog.").

        Returns
        -------
        DetectionResult
            Detection results with bounding boxes in normalized coordinates.

        Raises
        ------
        RuntimeError
            If detection fails or model is not loaded.
        """

    def unload(self) -> None:
        """Drop the model reference and release cached CUDA memory.

        Does nothing when no model is currently loaded.
        """
        if self.model is None:
            return

        del self.model
        self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Detection model unloaded and memory cleared")
class YOLOWorldLoader(DetectionModelLoader):
    """Loader for YOLO-World v2.1 open-vocabulary detection model.

    YOLO-World v2.1 achieves real-time performance (52 FPS) with strong accuracy
    on open-vocabulary object detection tasks.
    """

    def load(self) -> None:
        """Load YOLO-World v2.1 model with configured settings.

        The model is moved to ``config.device`` only when CUDA is available.

        Raises
        ------
        RuntimeError
            If importing Ultralytics or constructing the model fails.
        """
        try:
            from ultralytics import YOLO  # type: ignore[attr-defined]

            logger.info(f"Loading YOLO-World v2.1 from {self.config.model_id}")
            self.model = YOLO(self.config.model_id)
            if torch.cuda.is_available():
                self.model.to(self.config.device)
            logger.info("YOLO-World v2.1 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load YOLO-World v2.1: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def detect(
        self,
        image: Image.Image,
        text_prompt: str,
    ) -> DetectionResult:
        """Detect objects using YOLO-World v2.1 with text prompts.

        Parameters
        ----------
        image : Image.Image
            PIL Image to process.
        text_prompt : str
            Period-separated class names (e.g., "person. car. dog.").
            Empty fragments, such as the one after a trailing period, are
            ignored.

        Returns
        -------
        DetectionResult
            Detections at or above ``config.confidence_threshold`` with
            normalized boxes. Empty when the prompt names no classes.

        Raises
        ------
        RuntimeError
            If the model is not loaded or inference fails.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        import time

        try:
            start_time = time.perf_counter()
            width, height = image.size

            # Drop empty fragments so "person. car." does not register an
            # empty class name with the model.
            class_names = [name.strip() for name in text_prompt.split(".") if name.strip()]
            if not class_names:
                logger.warning("Text prompt contains no class names; returning no detections")
                return DetectionResult(
                    detections=[],
                    image_width=width,
                    image_height=height,
                    processing_time=time.perf_counter() - start_time,
                )
            self.model.set_classes(class_names)

            # Hand Ultralytics the PIL image itself: it reads PIL input as RGB
            # but numpy arrays as BGR, so np.array(image) would swap channels.
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")
            results = self.model(rgb_image, verbose=False)[0]

            detections = []
            boxes = results.boxes
            if boxes is not None:
                # Convert each tensor once instead of once per box.
                coords = boxes.xyxy.cpu().numpy()
                scores = boxes.conf.cpu().numpy()
                class_ids = boxes.cls.cpu().numpy()
                for (x1, y1, x2, y2), score, cls_id in zip(
                    coords, scores, class_ids, strict=False
                ):
                    conf = float(score)
                    if conf < self.config.confidence_threshold:
                        continue
                    bbox = BoundingBox(
                        x1=float(x1) / width,
                        y1=float(y1) / height,
                        x2=float(x2) / width,
                        y2=float(y2) / height,
                    )
                    label = self.model.names[int(cls_id)]
                    detections.append(Detection(bbox=bbox, confidence=conf, label=label))

            processing_time = time.perf_counter() - start_time
            return DetectionResult(
                detections=detections,
                image_width=width,
                image_height=height,
                processing_time=processing_time,
            )
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise RuntimeError(f"Object detection failed: {e}") from e
class GroundingDINOLoader(DetectionModelLoader):
    """Loader for Grounding DINO 1.5 open-vocabulary detection model.

    Grounding DINO 1.5 achieves 52.5 AP on COCO with zero-shot open-world
    object detection capabilities.
    """

    def load(self) -> None:
        """Load Grounding DINO 1.5 model with configured settings.

        ``config.model_id`` is used as the path to the weights file. The
        model is moved to ``config.device`` only when CUDA is available.

        Raises
        ------
        RuntimeError
            If importing GroundingDINO or loading the weights fails.
        """
        try:
            from groundingdino.util.inference import load_model

            logger.info(f"Loading Grounding DINO 1.5 from {self.config.model_id}")
            # NOTE(review): relative path, resolved against the current
            # working directory - confirm the deployment layout provides it.
            config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
            weights_path = self.config.model_id
            self.model = load_model(config_path, weights_path)
            if torch.cuda.is_available():
                self.model.to(self.config.device)
            logger.info("Grounding DINO 1.5 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Grounding DINO 1.5: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def detect(
        self,
        image: Image.Image,
        text_prompt: str,
    ) -> DetectionResult:
        """Detect objects using Grounding DINO 1.5 with text prompts.

        Parameters
        ----------
        image : Image.Image
            PIL Image to process.
        text_prompt : str
            Caption describing the objects to detect (e.g., "person. car.").

        Returns
        -------
        DetectionResult
            Detections with boxes converted from center/size form to
            normalized corner coordinates.

        Raises
        ------
        RuntimeError
            If the model is not loaded or inference fails.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        import time

        try:
            import groundingdino.datasets.transforms as T
            from groundingdino.util.inference import predict

            start_time = time.perf_counter()
            width, height = image.size

            # predict() expects the normalized image tensor, not a PIL image;
            # this is the same pipeline GroundingDINO's load_image() applies.
            transform = T.Compose(
                [
                    T.RandomResize([800], max_size=1333),
                    T.ToTensor(),
                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
            image_tensor, _ = transform(image.convert("RGB"), None)

            # Mirror load(): the model only left the CPU if CUDA was available.
            # predict() otherwise defaults to "cuda" and fails on CPU-only hosts.
            device = self.config.device if torch.cuda.is_available() else "cpu"

            boxes, logits, phrases = predict(
                model=self.model,
                image=image_tensor,
                caption=text_prompt,
                box_threshold=self.config.confidence_threshold,
                text_threshold=0.25,
                device=device,
            )

            detections = []
            for box, conf, phrase in zip(boxes, logits, phrases, strict=False):
                # Boxes arrive as (center_x, center_y, width, height).
                x_center, y_center, w, h = box.cpu().numpy()
                bbox = BoundingBox(
                    x1=float(x_center - w / 2),
                    y1=float(y_center - h / 2),
                    x2=float(x_center + w / 2),
                    y2=float(y_center + h / 2),
                )
                detections.append(Detection(bbox=bbox, confidence=float(conf), label=phrase))

            processing_time = time.perf_counter() - start_time
            return DetectionResult(
                detections=detections,
                image_width=width,
                image_height=height,
                processing_time=processing_time,
            )
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise RuntimeError(f"Object detection failed: {e}") from e
class OWLv2Loader(DetectionModelLoader):
    """Loader for OWLv2 open-vocabulary detection model.

    OWLv2 uses scaled training data and achieves strong performance on rare
    and novel object classes.
    """

    def load(self) -> None:
        """Load OWLv2 model and processor with configured settings.

        The model is moved to ``config.device`` only when CUDA is available
        and is then switched to eval mode.

        Raises
        ------
        RuntimeError
            If importing transformers or loading the model fails.
        """
        try:
            from transformers import Owlv2ForObjectDetection, Owlv2Processor

            logger.info(f"Loading OWLv2 from {self.config.model_id}")
            cache_dir = str(self.config.cache_dir) if self.config.cache_dir else None
            self.processor = Owlv2Processor.from_pretrained(
                self.config.model_id,
                cache_dir=cache_dir,
            )
            self.model = Owlv2ForObjectDetection.from_pretrained(
                self.config.model_id,
                cache_dir=cache_dir,
            )
            if torch.cuda.is_available():
                self.model.to(self.config.device)
            self.model.eval()
            logger.info("OWLv2 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load OWLv2: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def detect(
        self,
        image: Image.Image,
        text_prompt: str,
    ) -> DetectionResult:
        """Detect objects using OWLv2 with text prompts.

        Parameters
        ----------
        image : Image.Image
            PIL Image to process.
        text_prompt : str
            Period-separated queries (e.g., "person. car. dog."); empty
            fragments are ignored.

        Returns
        -------
        DetectionResult
            Detections above ``config.confidence_threshold`` with normalized
            boxes, each labelled with the query it matched.

        Raises
        ------
        RuntimeError
            If the model is not loaded or inference fails.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        import time

        try:
            start_time = time.perf_counter()
            width, height = image.size
            text_queries = [c.strip() for c in text_prompt.split(".") if c.strip()]

            # Follow the model's real device rather than config.device: load()
            # leaves the model on CPU when CUDA is unavailable, and sending
            # inputs to "cuda" in that case would fail.
            device = next(self.model.parameters()).device

            inputs = self.processor(text=text_queries, images=image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            target_sizes = torch.tensor([[height, width]], device=device)
            results = self.processor.post_process_object_detection(
                outputs=outputs,
                threshold=self.config.confidence_threshold,
                target_sizes=target_sizes,
            )[0]

            detections = []
            for box, score, label_idx in zip(
                results["boxes"], results["scores"], results["labels"], strict=False
            ):
                x1, y1, x2, y2 = box.cpu().numpy()
                bbox = BoundingBox(
                    x1=float(x1) / width,
                    y1=float(y1) / height,
                    x2=float(x2) / width,
                    y2=float(y2) / height,
                )
                # Label indices refer to positions in text_queries.
                label = text_queries[int(label_idx)]
                detections.append(Detection(bbox=bbox, confidence=float(score), label=label))

            processing_time = time.perf_counter() - start_time
            return DetectionResult(
                detections=detections,
                image_width=width,
                image_height=height,
                processing_time=processing_time,
            )
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise RuntimeError(f"Object detection failed: {e}") from e
class Florence2Loader(DetectionModelLoader):
    """Loader for Florence-2 unified vision model.

    Florence-2 is a 230M parameter model that supports multiple vision tasks
    including object detection, captioning, and grounding.
    """

    def load(self) -> None:
        """Load Florence-2 model and processor with configured settings.

        Half precision is used only when the model will sit on a CUDA device;
        otherwise it is loaded in float32 and stays on the CPU.

        Raises
        ------
        RuntimeError
            If importing transformers or loading the model fails.
        """
        try:
            from transformers import (
                AutoModelForCausalLM,
                AutoProcessor,
            )

            logger.info(f"Loading Florence-2 from {self.config.model_id}")
            cache_dir = str(self.config.cache_dir) if self.config.cache_dir else None
            cuda_available = torch.cuda.is_available()
            on_cuda = cuda_available and str(self.config.device).startswith("cuda")
            # float16 is kept off the CPU, where half-precision ops are
            # poorly supported.
            dtype = torch.float16 if on_cuda else torch.float32

            self.processor = AutoProcessor.from_pretrained(
                self.config.model_id,
                cache_dir=cache_dir,
                trust_remote_code=True,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_id,
                cache_dir=cache_dir,
                trust_remote_code=True,
                torch_dtype=dtype,
            )
            if cuda_available:
                self.model.to(self.config.device)
            self.model.eval()
            logger.info("Florence-2 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Florence-2: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def detect(
        self,
        image: Image.Image,
        text_prompt: str,
    ) -> DetectionResult:
        """Detect objects using Florence-2 with text prompts.

        Parameters
        ----------
        image : Image.Image
            PIL Image to process.
        text_prompt : str
            Caption whose phrases are to be grounded in the image.

        Returns
        -------
        DetectionResult
            Grounded phrases with normalized boxes; every detection has
            confidence 1.0.

        Raises
        ------
        RuntimeError
            If the model is not loaded or inference fails.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        import time

        try:
            start_time = time.perf_counter()
            width, height = image.size
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")

            task_token = "<CAPTION_TO_PHRASE_GROUNDING>"
            task_prompt = f"{task_token}{text_prompt}"

            # Match the model's real device and dtype so float32 pixel values
            # are not fed to a float16 model, and nothing is sent to a device
            # the model is not on.
            reference = next(self.model.parameters())
            device, dtype = reference.device, reference.dtype

            inputs = self.processor(text=task_prompt, images=rgb_image, return_tensors="pt")
            inputs = {
                k: v.to(device, dtype) if torch.is_floating_point(v) else v.to(device)
                for k, v in inputs.items()
            }

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    num_beams=3,
                )

            # Keep special tokens: boxes are encoded as location tokens that
            # the processor's own post-processing turns into pixel coordinates.
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=False)[0]
            parsed = self.processor.post_process_generation(
                generated_text,
                task=task_token,
                image_size=(width, height),
            )
            grounding = parsed.get(task_token, {}) if isinstance(parsed, dict) else {}
            detections = self._parse_florence_output(grounding, width, height)

            processing_time = time.perf_counter() - start_time
            return DetectionResult(
                detections=detections,
                image_width=width,
                image_height=height,
                processing_time=processing_time,
            )
        except Exception as e:
            logger.error(f"Detection failed: {e}")
            raise RuntimeError(f"Object detection failed: {e}") from e

    def _parse_florence_output(
        self, result: str | dict[str, Any], width: int, height: int
    ) -> list[Detection]:
        """Parse Florence-2 output into Detection objects.

        Parameters
        ----------
        result : str | dict
            Either a mapping with ``"bboxes"`` and ``"labels"`` entries, or a
            JSON string encoding such a mapping. Boxes are
            ``(x1, y1, x2, y2)`` in pixels.
        width : int
            Image width for normalization.
        height : int
            Image height for normalization.

        Returns
        -------
        list[Detection]
            Parsed detections with normalized coordinates. Malformed input is
            logged as a warning, and whatever was parsed before the problem
            is returned.
        """
        import json

        detections: list[Detection] = []
        try:
            data = json.loads(result) if isinstance(result, str) else result
            if isinstance(data, dict) and "bboxes" in data and "labels" in data:
                for bbox, label in zip(data["bboxes"], data["labels"], strict=False):
                    x1, y1, x2, y2 = bbox
                    normalized_bbox = BoundingBox(
                        x1=float(x1) / width,
                        y1=float(y1) / height,
                        x2=float(x2) / width,
                        y2=float(y2) / height,
                    )
                    detections.append(
                        Detection(
                            bbox=normalized_bbox,
                            # Only boxes and labels are read from the output;
                            # no score is available, so 1.0 is used.
                            confidence=1.0,
                            label=label,
                        )
                    )
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, ZeroDivisionError) as e:
            logger.warning(f"Failed to parse Florence-2 output: {e}")

        return detections
def create_detection_loader(model_name: str, config: DetectionConfig) -> DetectionModelLoader:
    """Build the loader that corresponds to a model name.

    The name is lower-cased and underscores are turned into hyphens; the
    first loader whose keyword occurs in the result is chosen.

    Parameters
    ----------
    model_name : str
        Name of the model to load. Supported values:
        - "yolo-world-v2" or "yoloworld"
        - "grounding-dino-1-5" or "groundingdino"
        - "owlv2" or "owl-v2"
        - "florence-2" or "florence2"
    config : DetectionConfig
        Configuration for model loading and inference.

    Returns
    -------
    DetectionModelLoader
        Loader instance for the specified model (not yet loaded).

    Raises
    ------
    ValueError
        If model_name is not recognized.
    """
    normalized = model_name.lower().replace("_", "-")

    # Ordered (keywords, loader class) pairs; the first match wins.
    dispatch = (
        (("yolo-world", "yoloworld"), YOLOWorldLoader),
        (("grounding-dino", "groundingdino"), GroundingDINOLoader),
        (("owl",), OWLv2Loader),
        (("florence",), Florence2Loader),
    )
    for keywords, loader_cls in dispatch:
        if any(keyword in normalized for keyword in keywords):
            return loader_cls(config)

    raise ValueError(
        f"Unknown model name: {model_name}. Supported models: "
        "yolo-world-v2, grounding-dino-1-5, owlv2, florence-2"
    )