Source code for model_manager

"""Model management with dynamic loading, memory budget validation, and LRU eviction.

This module provides a ModelManager class that handles loading and unloading
of AI models based on available GPU memory. Models are loaded on demand and
automatically evicted when memory pressure occurs.
"""

import logging
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any

import torch
import yaml
from opentelemetry import trace

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


class ModelConfig:
    """Configuration for a single model variant.

    Attributes
    ----------
    model_id : str
        Hugging Face model identifier.
    framework : str
        Inference framework (sglang, vllm, pytorch).
    vram_gb : float
        VRAM requirement in GB.
    quantization : str | None
        Quantization method (4bit, 8bit, awq, etc.).
    speed : str
        Speed category (fast, medium, slow).
    description : str
        Human-readable description.
    fps : int | None
        Processing speed in frames per second (for vision models).
    """
    def __init__(self, config_dict: dict[str, Any]) -> None:
        """Initialize model configuration from dictionary.

        Parameters
        ----------
        config_dict : dict[str, Any]
            Dictionary containing model configuration parameters.
        """
        self.model_id: str = config_dict["model_id"]
        self.framework: str = config_dict["framework"]
        self.vram_gb: float = config_dict.get("vram_gb", 0)
        self.quantization: str | None = config_dict.get("quantization")
        self.speed: str = config_dict.get("speed", "medium")
        self.description: str = config_dict.get("description", "")
        self.fps: int | None = config_dict.get("fps")
    @property
    def vram_bytes(self) -> int:
        """Convert VRAM requirement from GB to bytes.

        Returns
        -------
        int
            VRAM requirement in bytes.
        """
        return int(self.vram_gb * 1024 * 1024 * 1024)


class TaskConfig:
    """Configuration for a task type with multiple model options.

    Attributes
    ----------
    task_name : str
        Name of the task.
    selected : str
        Currently selected model name.
    options : dict[str, ModelConfig]
        Available model options for this task.
    """
    def __init__(self, task_name: str, config_dict: dict[str, Any]) -> None:
        """Initialize task configuration from dictionary.

        Parameters
        ----------
        task_name : str
            Name of the task (e.g., "video_summarization").
        config_dict : dict[str, Any]
            Dictionary containing task configuration.
        """
        self.task_name = task_name
        self.selected = config_dict["selected"]
        self.options: dict[str, ModelConfig] = {
            name: ModelConfig(opt_dict)
            for name, opt_dict in config_dict["options"].items()
        }
    def get_selected_config(self) -> ModelConfig:
        """Get the currently selected model configuration.

        Returns
        -------
        ModelConfig
            Configuration for the selected model.
        """
        return self.options[self.selected]


class InferenceConfig:
    """Global inference configuration settings.

    Attributes
    ----------
    max_memory_per_model : str
        Maximum memory per model ('auto' or specific value).
    offload_threshold : float
        Memory usage threshold for offloading (0.0 to 1.0).
    warmup_on_startup : bool
        Whether to load all models on startup.
    default_batch_size : int
        Default batch size for inference.
    max_batch_size : int
        Maximum batch size for inference.
    """
    def __init__(self, config_dict: dict[str, Any]) -> None:
        """Initialize inference configuration from dictionary.

        Parameters
        ----------
        config_dict : dict[str, Any]
            Dictionary containing inference configuration.
        """
        self.max_memory_per_model = config_dict.get("max_memory_per_model", "auto")
        self.offload_threshold: float = config_dict.get("offload_threshold", 0.85)
        self.warmup_on_startup: bool = config_dict.get("warmup_on_startup", False)
        self.default_batch_size: int = config_dict.get("default_batch_size", 1)
        self.max_batch_size: int = config_dict.get("max_batch_size", 8)

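Taken together, the three classes above describe how a models.yaml file is expected to be shaped. The sketch below is illustrative, not part of the module: the task name, option name, and model_id are placeholders, and only the keys read by the constructors above are shown.

    import yaml

    # Hypothetical models.yaml layout; names are placeholders.
    example_yaml = """
    models:
      video_summarization:
        selected: fast_variant
        options:
          fast_variant:
            model_id: org/some-model
            framework: vllm
            vram_gb: 8
            quantization: awq
            speed: fast
            description: Example entry
    inference:
      offload_threshold: 0.85
      warmup_on_startup: false
      default_batch_size: 1
      max_batch_size: 8
    """

    raw = yaml.safe_load(example_yaml)
    task = TaskConfig("video_summarization", raw["models"]["video_summarization"])
    inference = InferenceConfig(raw["inference"])
    assert task.get_selected_config().vram_bytes == 8 * 1024 ** 3
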
class ModelManager:
    """Manages loading, unloading, and memory management of AI models.

    This class handles dynamic model loading based on memory availability,
    implements LRU eviction when memory pressure occurs, and provides
    utilities for VRAM monitoring.

    Attributes
    ----------
    config_path : Path
        Path to models.yaml configuration file.
    config : dict[str, Any]
        Parsed configuration dictionary.
    loaded_models : OrderedDict[str, Any]
        Currently loaded models (LRU ordered).
    model_load_times : dict[str, float]
        Timestamp when each model was loaded.
    model_memory_usage : dict[str, int]
        Actual memory usage per model in bytes.
    tasks : dict[str, TaskConfig]
        Task configurations.
    inference_config : InferenceConfig
        Global inference settings.
    """
    def __init__(self, config_path: str) -> None:
        """Initialize ModelManager with configuration file.

        Parameters
        ----------
        config_path : str
            Path to models.yaml configuration file.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()
        self.loaded_models: OrderedDict[str, Any] = OrderedDict()
        self.model_load_times: dict[str, float] = {}
        self.model_memory_usage: dict[str, int] = {}
        logger.info(f"ModelManager initialized with config from {config_path}")
    def _load_config(self) -> dict[str, Any]:
        """Load configuration from YAML file.

        Returns
        -------
        dict[str, Any]
            Dictionary containing parsed configuration.

        Raises
        ------
        FileNotFoundError
            If configuration file does not exist.
        yaml.YAMLError
            If configuration file is invalid.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with self.config_path.open() as f:
            config: dict[str, Any] = yaml.safe_load(f)

        self.tasks: dict[str, TaskConfig] = {
            task_name: TaskConfig(task_name, task_config)
            for task_name, task_config in config["models"].items()
        }
        self.inference_config = InferenceConfig(config["inference"])
        return config
    def get_available_vram(self) -> int:
        """Get available GPU memory in bytes.

        Returns:
            Available VRAM in bytes
        """
        if not torch.cuda.is_available():
            return 0
        device = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device).total_memory
        allocated = torch.cuda.memory_allocated(device)
        return total - allocated
    def get_total_vram(self) -> int:
        """Get total GPU memory in bytes.

        Returns:
            Total VRAM in bytes
        """
        if not torch.cuda.is_available():
            return 0
        device = torch.cuda.current_device()
        return torch.cuda.get_device_properties(device).total_memory
    def get_memory_usage_percentage(self) -> float:
        """Get current GPU memory usage as a fraction of total VRAM.

        Returns:
            Memory usage as a fraction (0.0 to 1.0)
        """
        total = self.get_total_vram()
        if total == 0:
            return 0.0
        allocated = torch.cuda.memory_allocated()
        return allocated / total
    def check_memory_available(self, required_bytes: int) -> bool:
        """Check if sufficient memory is available for model loading.

        Args:
            required_bytes: Required memory in bytes

        Returns:
            True if sufficient memory is available
        """
        available = self.get_available_vram()
        return available >= required_bytes
    def get_lru_model(self) -> str | None:
        """Get least recently used model identifier.

        Returns:
            Task name of LRU model, or None if no models loaded
        """
        if not self.loaded_models:
            return None
        return next(iter(self.loaded_models))
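The LRU bookkeeping relies entirely on the ordering guarantees of OrderedDict: the first key is always the least recently used entry, and move_to_end() marks a key as freshly used. A standalone sketch (task names are illustrative):

    from collections import OrderedDict

    loaded: OrderedDict[str, str] = OrderedDict()
    loaded["captioning"] = "model-a"       # inserted first, so initially the LRU entry
    loaded["transcription"] = "model-b"
    loaded.move_to_end("captioning")       # "captioning" was just used again
    assert next(iter(loaded)) == "transcription"  # now the eviction candidate
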
    @tracer.start_as_current_span("evict_lru_model")
    async def evict_lru_model(self) -> str | None:
        """Evict the least recently used model from memory.

        Returns:
            Task name of evicted model, or None if no models to evict
        """
        lru_task = self.get_lru_model()
        if lru_task is None:
            logger.warning("No models to evict")
            return None

        logger.info(f"Evicting LRU model: {lru_task}")
        await self.unload_model(lru_task)
        return lru_task
    @tracer.start_as_current_span("unload_model")
    async def unload_model(self, task_type: str) -> None:
        """Unload a model from memory.

        Args:
            task_type: Task type of model to unload
        """
        if task_type not in self.loaded_models:
            logger.warning(f"Model {task_type} not loaded")
            return

        logger.info(f"Unloading model: {task_type}")
        del self.loaded_models[task_type]
        del self.model_load_times[task_type]
        del self.model_memory_usage[task_type]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Model {task_type} unloaded successfully")
    @tracer.start_as_current_span("load_model")
    async def load_model(self, task_type: str) -> Any:
        """Load a model for the specified task type.

        This method loads the selected model for the task, handling memory
        management and eviction if necessary.

        Args:
            task_type: Task type to load model for

        Returns:
            Loaded model object

        Raises:
            ValueError: If task type is invalid or model cannot be loaded
            RuntimeError: If insufficient memory after eviction attempts
        """
        if task_type not in self.tasks:
            raise ValueError(f"Invalid task type: {task_type}")

        if task_type in self.loaded_models:
            self.loaded_models.move_to_end(task_type)
            logger.info(f"Model {task_type} already loaded, moved to end")
            return self.loaded_models[task_type]

        task_config = self.tasks[task_type]
        model_config = task_config.get_selected_config()

        logger.info(
            f"Loading model for {task_type}: {model_config.model_id} "
            f"({model_config.vram_gb}GB VRAM required)"
        )

        while not self.check_memory_available(model_config.vram_bytes):
            memory_usage = self.get_memory_usage_percentage()
            logger.info(f"Insufficient memory (usage: {memory_usage:.1%}), evicting LRU model")
            evicted = await self.evict_lru_model()
            if evicted is None:
                raise RuntimeError(f"Insufficient memory for {task_type} and no models to evict")

        memory_before = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        model = await self._load_model_implementation(task_type, model_config)
        memory_after = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        actual_memory = memory_after - memory_before

        self.loaded_models[task_type] = model
        self.model_load_times[task_type] = time.time()
        self.model_memory_usage[task_type] = actual_memory

        logger.info(
            f"Model {task_type} loaded successfully "
            f"(actual memory: {actual_memory / 1024**3:.2f}GB)"
        )
        return model
    async def _load_model_implementation(self, task_type: str, model_config: ModelConfig) -> Any:
        """Load model implementation based on framework.

        This is a placeholder that will be replaced with actual model loading
        logic when model loaders are implemented in Phase 3.

        Args:
            task_type: Task type being loaded
            model_config: Model configuration

        Returns:
            Loaded model object
        """
        logger.info(f"Loading {model_config.framework} model: {model_config.model_id}")
        return {
            "task_type": task_type,
            "model_id": model_config.model_id,
            "framework": model_config.framework,
            "config": model_config,
        }
    async def get_model(self, task_type: str) -> Any:
        """Get model for task type, loading if necessary.

        Args:
            task_type: Task type to get model for

        Returns:
            Loaded model object
        """
        if task_type in self.loaded_models:
            self.loaded_models.move_to_end(task_type)
            return self.loaded_models[task_type]
        return await self.load_model(task_type)
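A minimal sketch of driving the manager from async code. The config path and task name are placeholders (the config file must exist and define the task), and until Phase 3 the returned "model" is the placeholder dict from _load_model_implementation:

    import asyncio

    async def main() -> None:
        manager = ModelManager("config/models.yaml")  # hypothetical path

        # First call loads the model, evicting LRU models if VRAM is tight;
        # later calls return the cached object and refresh its LRU position.
        model = await manager.get_model("video_summarization")
        print(manager.get_loaded_models())

        await manager.shutdown()

    asyncio.run(main())
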
    def get_loaded_models(self) -> dict[str, dict[str, Any]]:
        """Get information about currently loaded models.

        Returns:
            Dictionary mapping task types to model information
        """
        result = {}
        for task_type in self.loaded_models:
            result[task_type] = {
                "model_id": self.tasks[task_type].get_selected_config().model_id,
                "memory_usage_gb": self.model_memory_usage.get(task_type, 0) / 1024**3,
                "load_time": self.model_load_times.get(task_type),
            }
        return result
    def get_model_config(self, task_type: str) -> TaskConfig | None:
        """Get configuration for a task type.

        Args:
            task_type: Task type to get configuration for

        Returns:
            Task configuration, or None if task type is invalid
        """
        return self.tasks.get(task_type)
    async def set_selected_model(self, task_type: str, model_name: str) -> None:
        """Change the selected model for a task type.

        If the task's model is currently loaded, it will be unloaded and the
        new model will be loaded.

        Args:
            task_type: Task type to update
            model_name: Name of model option to select

        Raises:
            ValueError: If task type or model name is invalid
        """
        if task_type not in self.tasks:
            raise ValueError(f"Invalid task type: {task_type}")

        task_config = self.tasks[task_type]
        if model_name not in task_config.options:
            raise ValueError(f"Invalid model name: {model_name} for task {task_type}")

        old_selection = task_config.selected
        task_config.selected = model_name
        self.config["models"][task_type]["selected"] = model_name
        logger.info(f"Changed {task_type} model from {old_selection} to {model_name}")

        if task_type in self.loaded_models:
            await self.unload_model(task_type)
            await self.load_model(task_type)
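Switching variants at runtime reuses the unload/load path above. Continuing the async sketch from get_model, and assuming the task defines an option with the illustrative name used here:

    # Hypothetical: swap the task to another option from its "options" block;
    # if the old model is resident it is unloaded and the new one loaded.
    await manager.set_selected_model("video_summarization", "high_quality_variant")
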
    def validate_memory_budget(self) -> dict[str, Any]:
        """Validate that all selected models can fit in available memory.

        Returns:
            Dictionary with validation results
        """
        total_vram = self.get_total_vram()
        total_required = 0
        model_requirements = {}

        for task_type, task_config in self.tasks.items():
            model_config = task_config.get_selected_config()
            model_requirements[task_type] = {
                "model_id": model_config.model_id,
                "vram_gb": model_config.vram_gb,
            }
            total_required += model_config.vram_bytes

        threshold = self.inference_config.offload_threshold
        max_allowed = int(total_vram * threshold)

        return {
            "valid": total_required <= max_allowed,
            "total_vram_gb": total_vram / 1024**3,
            "total_required_gb": total_required / 1024**3,
            "threshold": threshold,
            "max_allowed_gb": max_allowed / 1024**3,
            "model_requirements": model_requirements,
        }
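The budget check is plain arithmetic: the sum of the selected models' VRAM requirements must stay under total VRAM multiplied by offload_threshold. A hedged startup check (continuing from the manager instance in the earlier sketch; the log message is illustrative) might warn rather than fail, since load_model can still swap models via LRU eviction:

    budget = manager.validate_memory_budget()
    if not budget["valid"]:
        logger.warning(
            f"Selected models need {budget['total_required_gb']:.1f}GB but only "
            f"{budget['max_allowed_gb']:.1f}GB fits under the {budget['threshold']:.0%} "
            f"threshold; models will be swapped in and out via LRU eviction."
        )
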
    async def warmup_models(self) -> None:
        """Load all selected models if warmup_on_startup is enabled."""
        if not self.inference_config.warmup_on_startup:
            logger.info("Warmup disabled, skipping model loading")
            return

        logger.info("Warming up all selected models")
        for task_type in self.tasks:
            try:
                await self.load_model(task_type)
            except Exception as e:
                logger.error(f"Failed to warmup {task_type}: {e}")
    async def shutdown(self) -> None:
        """Unload all models and clean up resources."""
        logger.info("Shutting down ModelManager")
        for task_type in list(self.loaded_models.keys()):
            await self.unload_model(task_type)