Source code for vlm_loader

"""Vision Language Model loader with support for multiple VLM architectures.

This module provides a unified interface for loading and running inference with
various Vision Language Models including Llama 4 Maverick, Gemma 3, InternVL3,
Pixtral Large, and Qwen2.5-VL. Models can be loaded with different quantization
strategies and inference frameworks (SGLang or vLLM).
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any

import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    Qwen2VLForConditionalGeneration,
)

logger = logging.getLogger(__name__)


class QuantizationType(str, Enum):
    """Supported quantization types for model compression."""

    NONE = "none"
    FOUR_BIT = "4bit"
    EIGHT_BIT = "8bit"
    AWQ = "awq"


class InferenceFramework(str, Enum):
    """Supported inference frameworks for model execution."""

    SGLANG = "sglang"
    VLLM = "vllm"
    TRANSFORMERS = "transformers"


@dataclass
class VLMConfig:
    """Configuration for Vision Language Model loading and inference.

    Parameters
    ----------
    model_id : str
        HuggingFace model identifier or local path.
    quantization : QuantizationType
        Quantization strategy to apply.
    framework : InferenceFramework
        Inference framework to use for model execution.
    max_memory_gb : int | None, default=None
        Maximum GPU memory to allocate in GB. If None, uses all available.
    device : str, default="cuda"
        Device to load the model on.
    trust_remote_code : bool, default=True
        Whether to trust remote code from HuggingFace.
    """

    model_id: str
    quantization: QuantizationType = QuantizationType.FOUR_BIT
    framework: InferenceFramework = InferenceFramework.SGLANG
    max_memory_gb: int | None = None
    device: str = "cuda"
    trust_remote_code: bool = True


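# A minimal configuration sketch (illustrative only; the model id and memory
# budget below are assumptions chosen for demonstration, not defaults shipped
# with this module):
#
#     config = VLMConfig(
#         model_id="OpenGVLab/InternVL3-78B",
#         quantization=QuantizationType.FOUR_BIT,
#         framework=InferenceFramework.TRANSFORMERS,
#         max_memory_gb=80,
#     )

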
class VLMLoader(ABC):
    """Abstract base class for Vision Language Model loaders.

    All VLM loaders must implement the load and generate methods.
    """

    def __init__(self, config: VLMConfig) -> None:
        """Initialize the VLM loader with configuration.

        Parameters
        ----------
        config : VLMConfig
            Configuration for model loading and inference.
        """
        self.config = config
        self.model = None
        self.processor = None
        self.tokenizer = None

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory with configured settings.

        Raises
        ------
        RuntimeError
            If model loading fails.
        """
        pass

    @abstractmethod
    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt.

        Parameters
        ----------
        images : list[Image.Image]
            List of PIL images to process.
        prompt : str
            Text prompt for the model.
        max_new_tokens : int, default=512
            Maximum number of tokens to generate.
        temperature : float, default=0.7
            Sampling temperature for generation.

        Returns
        -------
        str
            Generated text response.

        Raises
        ------
        RuntimeError
            If generation fails or model is not loaded.
        """
        pass

    def unload(self) -> None:
        """Unload the model from memory to free GPU resources."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:  # type: ignore[unreachable]
            del self.processor
            self.processor = None
        if self.tokenizer is not None:  # type: ignore[unreachable]
            del self.tokenizer
            self.tokenizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded and memory cleared")

    def _get_quantization_config(self) -> Any:
        """Create quantization configuration for model loading.

        Returns
        -------
        BitsAndBytesConfig | None
            Quantization config for bitsandbytes, or None if no quantization.
        """
        if self.config.quantization == QuantizationType.FOUR_BIT:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        if self.config.quantization == QuantizationType.EIGHT_BIT:
            return BitsAndBytesConfig(
                load_in_8bit=True,
                bnb_8bit_compute_dtype=torch.bfloat16,
            )
        return None


class Llama4MaverickLoader(VLMLoader):
    """Loader for Llama 4 Maverick Vision Language Model.

    Llama 4 Maverick is a 400B parameter MoE model with 17B active parameters,
    supporting multimodal input with a 1M token context length.
    """

    def load(self) -> None:
        """Load Llama 4 Maverick model with configured settings."""
        try:
            logger.info(
                f"Loading Llama 4 Maverick from {self.config.model_id} "
                f"with {self.config.quantization} quantization"
            )
            if self.config.framework == InferenceFramework.SGLANG:
                self._load_with_sglang()
            elif self.config.framework == InferenceFramework.VLLM:
                self._load_with_vllm()
            else:
                self._load_with_transformers()
            logger.info("Llama 4 Maverick loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Llama 4 Maverick: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_with_sglang(self) -> None:
        """Load model using SGLang framework for optimized inference."""
        try:
            import sglang as sgl

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes-4bit"
            elif self.config.quantization == QuantizationType.AWQ:
                quantization_str = "awq"

            runtime = sgl.Runtime(
                model_path=self.config.model_id,
                tokenizer_path=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                mem_fraction_static=0.8 if self.config.max_memory_gb else 0.9,
            )
            self.model = runtime
            logger.info("Model loaded with SGLang")
        except ImportError:
            logger.warning("SGLang not available, falling back to transformers")
            self._load_with_transformers()
        except Exception as e:
            logger.error(f"SGLang loading failed: {e}")
            raise

    def _load_with_vllm(self) -> None:
        """Load model using vLLM framework for high-throughput inference."""
        try:
            from vllm import LLM

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes"
            elif self.config.quantization == QuantizationType.AWQ:
                quantization_str = "awq"

            self.model = LLM(
                model=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                gpu_memory_utilization=0.9,
            )
            logger.info("Model loaded with vLLM")
        except ImportError:
            logger.warning("vLLM not available, falling back to transformers")
            self._load_with_transformers()
        except Exception as e:
            logger.error(f"vLLM loading failed: {e}")
            raise

    def _load_with_transformers(self) -> None:
        """Load model using HuggingFace Transformers library."""
        quantization_config = self._get_quantization_config()
        self.processor = AutoProcessor.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.model = AutoModelForVision2Seq.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=self.config.trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded with Transformers")

    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt using Llama 4 Maverick."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            if self.config.framework == InferenceFramework.SGLANG:
                return self._generate_with_sglang(images, prompt, max_new_tokens, temperature)
            if self.config.framework == InferenceFramework.VLLM:
                return self._generate_with_vllm(images, prompt, max_new_tokens, temperature)
            return self._generate_with_transformers(images, prompt, max_new_tokens, temperature)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def _generate_with_sglang(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using SGLang runtime."""
        import sglang as sgl

        @sgl.function  # type: ignore[misc]
        def image_qa(s: Any, images: Any, prompt: Any) -> None:
            for img in images:
                s += sgl.image(img)
            s += prompt
            s += sgl.gen("answer", max_tokens=max_new_tokens, temperature=temperature)

        state = image_qa.run(images=images, prompt=prompt, backend=self.model)
        return str(state["answer"])  # type: ignore[no-any-return]

    def _generate_with_vllm(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using vLLM engine."""
        from vllm import SamplingParams

        sampling_params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature)
        # vLLM takes multi-modal input as a dict with "prompt" and "multi_modal_data" keys
        outputs = self.model.generate(  # type: ignore[attr-defined]
            {"prompt": prompt, "multi_modal_data": {"image": images}},
            sampling_params=sampling_params,
        )
        return outputs[0].outputs[0].text  # type: ignore[no-any-return]

    def _generate_with_transformers(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using HuggingFace Transformers."""
        if self.processor is None or self.tokenizer is None:
            raise RuntimeError("Processor and tokenizer not initialized")

        inputs = self.processor(images=images, text=prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class Gemma3Loader(VLMLoader):
    """Loader for Gemma 3 27B Vision Language Model.

    Gemma 3 27B excels at document analysis, OCR, and multilingual tasks
    with fast inference speed.
    """

    def load(self) -> None:
        """Load Gemma 3 model with configured settings."""
        try:
            logger.info(
                f"Loading Gemma 3 from {self.config.model_id} "
                f"with {self.config.quantization} quantization"
            )
            if self.config.framework == InferenceFramework.SGLANG:
                self._load_with_sglang()
            elif self.config.framework == InferenceFramework.VLLM:
                self._load_with_vllm()
            else:
                self._load_with_transformers()
            logger.info("Gemma 3 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Gemma 3: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_with_sglang(self) -> None:
        """Load model using SGLang framework."""
        try:
            import sglang as sgl

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes-4bit"

            runtime = sgl.Runtime(
                model_path=self.config.model_id,
                tokenizer_path=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                mem_fraction_static=0.8 if self.config.max_memory_gb else 0.9,
            )
            self.model = runtime
            logger.info("Model loaded with SGLang")
        except ImportError:
            logger.warning("SGLang not available, falling back to transformers")
            self._load_with_transformers()

    def _load_with_vllm(self) -> None:
        """Load model using vLLM framework."""
        try:
            from vllm import LLM

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes"

            self.model = LLM(
                model=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                gpu_memory_utilization=0.9,
            )
            logger.info("Model loaded with vLLM")
        except ImportError:
            logger.warning("vLLM not available, falling back to transformers")
            self._load_with_transformers()

    def _load_with_transformers(self) -> None:
        """Load model using HuggingFace Transformers."""
        quantization_config = self._get_quantization_config()
        self.processor = AutoProcessor.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.model = AutoModelForVision2Seq.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=self.config.trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded with Transformers")

    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt using Gemma 3."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            if self.config.framework == InferenceFramework.SGLANG:
                return self._generate_with_sglang(images, prompt, max_new_tokens, temperature)
            if self.config.framework == InferenceFramework.VLLM:
                return self._generate_with_vllm(images, prompt, max_new_tokens, temperature)
            return self._generate_with_transformers(images, prompt, max_new_tokens, temperature)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def _generate_with_sglang(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using SGLang runtime."""
        import sglang as sgl

        @sgl.function  # type: ignore[misc]
        def image_qa(s: Any, images: Any, prompt: Any) -> None:
            for img in images:
                s += sgl.image(img)
            s += prompt
            s += sgl.gen("answer", max_tokens=max_new_tokens, temperature=temperature)

        state = image_qa.run(images=images, prompt=prompt, backend=self.model)
        return str(state["answer"])  # type: ignore[no-any-return]

    def _generate_with_vllm(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using vLLM engine."""
        from vllm import SamplingParams

        sampling_params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature)
        outputs = self.model.generate(  # type: ignore[attr-defined]
            {"prompt": prompt, "multi_modal_data": {"image": images}},
            sampling_params=sampling_params,
        )
        return outputs[0].outputs[0].text  # type: ignore[no-any-return]

    def _generate_with_transformers(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using HuggingFace Transformers."""
        if self.processor is None or self.tokenizer is None:
            raise RuntimeError("Processor and tokenizer not initialized")

        inputs = self.processor(images=images, text=prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class InternVL3Loader(VLMLoader):
    """Loader for InternVL3-78B Vision Language Model.

    InternVL3-78B achieves state-of-the-art results on vision benchmarks
    with strong scientific reasoning capabilities.
    """

    def load(self) -> None:
        """Load InternVL3 model with configured settings."""
        try:
            logger.info(
                f"Loading InternVL3 from {self.config.model_id} "
                f"with {self.config.quantization} quantization"
            )
            self._load_with_transformers()
            logger.info("InternVL3 loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load InternVL3: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_with_transformers(self) -> None:
        """Load model using HuggingFace Transformers."""
        quantization_config = self._get_quantization_config()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.model = AutoModel.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=self.config.trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded with Transformers")

    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt using InternVL3."""
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            pixel_values_list = []
            for image in images:
                # Relies on the remote-code InternVL implementation exposing a
                # `load_image` preprocessing helper on the loaded model object.
                pixel_values = self.model.load_image(image, max_num=12).to(torch.bfloat16).cuda()
                pixel_values_list.append(pixel_values)

            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "do_sample": True,
            }

            return str(
                self.model.chat(
                    self.tokenizer,
                    pixel_values_list[0] if len(pixel_values_list) == 1 else pixel_values_list,
                    prompt,
                    generation_config,
                )
            )
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e


class PixtralLargeLoader(VLMLoader):
    """Loader for Pixtral Large Vision Language Model.

    Pixtral Large is a 123B parameter model with 128k context length,
    optimized for batch processing of long documents.
    """

    def load(self) -> None:
        """Load Pixtral Large model with configured settings."""
        try:
            logger.info(
                f"Loading Pixtral Large from {self.config.model_id} "
                f"with {self.config.quantization} quantization"
            )
            if self.config.framework == InferenceFramework.VLLM:
                self._load_with_vllm()
            else:
                self._load_with_transformers()
            logger.info("Pixtral Large loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Pixtral Large: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_with_vllm(self) -> None:
        """Load model using vLLM framework."""
        try:
            from vllm import LLM

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes"
            elif self.config.quantization == QuantizationType.AWQ:
                quantization_str = "awq"

            self.model = LLM(
                model=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                gpu_memory_utilization=0.9,
            )
            logger.info("Model loaded with vLLM")
        except ImportError:
            logger.warning("vLLM not available, falling back to transformers")
            self._load_with_transformers()

    def _load_with_transformers(self) -> None:
        """Load model using HuggingFace Transformers."""
        quantization_config = self._get_quantization_config()
        self.processor = AutoProcessor.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.model = AutoModelForVision2Seq.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=self.config.trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded with Transformers")

    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt using Pixtral Large."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            if self.config.framework == InferenceFramework.VLLM:
                return self._generate_with_vllm(images, prompt, max_new_tokens, temperature)
            return self._generate_with_transformers(images, prompt, max_new_tokens, temperature)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def _generate_with_vllm(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using vLLM engine."""
        from vllm import SamplingParams

        sampling_params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature)
        outputs = self.model.generate(  # type: ignore[attr-defined]
            {"prompt": prompt, "multi_modal_data": {"image": images}},
            sampling_params=sampling_params,
        )
        return outputs[0].outputs[0].text  # type: ignore[no-any-return]

    def _generate_with_transformers(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using HuggingFace Transformers."""
        if self.processor is None or self.tokenizer is None:
            raise RuntimeError("Processor and tokenizer not initialized")

        inputs = self.processor(images=images, text=prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class Qwen25VLLoader(VLMLoader):
    """Loader for Qwen2.5-VL 72B Vision Language Model.

    Qwen2.5-VL 72B is a proven stable model with strong performance
    across vision-language tasks.
    """

    def load(self) -> None:
        """Load Qwen2.5-VL model with configured settings."""
        try:
            logger.info(
                f"Loading Qwen2.5-VL from {self.config.model_id} "
                f"with {self.config.quantization} quantization"
            )
            if self.config.framework == InferenceFramework.SGLANG:
                self._load_with_sglang()
            elif self.config.framework == InferenceFramework.VLLM:
                self._load_with_vllm()
            else:
                self._load_with_transformers()
            logger.info("Qwen2.5-VL loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Qwen2.5-VL: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_with_sglang(self) -> None:
        """Load model using SGLang framework."""
        try:
            import sglang as sgl

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes-4bit"
            elif self.config.quantization == QuantizationType.AWQ:
                quantization_str = "awq"

            runtime = sgl.Runtime(
                model_path=self.config.model_id,
                tokenizer_path=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                mem_fraction_static=0.8 if self.config.max_memory_gb else 0.9,
            )
            self.model = runtime
            logger.info("Model loaded with SGLang")
        except ImportError:
            logger.warning("SGLang not available, falling back to transformers")
            self._load_with_transformers()

    def _load_with_vllm(self) -> None:
        """Load model using vLLM framework."""
        try:
            from vllm import LLM

            quantization_str = None
            if self.config.quantization == QuantizationType.FOUR_BIT:
                quantization_str = "bitsandbytes"
            elif self.config.quantization == QuantizationType.AWQ:
                quantization_str = "awq"

            self.model = LLM(
                model=self.config.model_id,
                quantization=quantization_str,
                trust_remote_code=self.config.trust_remote_code,
                gpu_memory_utilization=0.9,
            )
            logger.info("Model loaded with vLLM")
        except ImportError:
            logger.warning("vLLM not available, falling back to transformers")
            self._load_with_transformers()

    def _load_with_transformers(self) -> None:
        """Load model using HuggingFace Transformers."""
        quantization_config = self._get_quantization_config()
        self.processor = AutoProcessor.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id, trust_remote_code=self.config.trust_remote_code
        )
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.config.model_id,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=self.config.trust_remote_code,
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded with Transformers")

    def generate(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text response from images and prompt using Qwen2.5-VL."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            if self.config.framework == InferenceFramework.SGLANG:
                return self._generate_with_sglang(images, prompt, max_new_tokens, temperature)
            if self.config.framework == InferenceFramework.VLLM:
                return self._generate_with_vllm(images, prompt, max_new_tokens, temperature)
            return self._generate_with_transformers(images, prompt, max_new_tokens, temperature)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def _generate_with_sglang(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using SGLang runtime."""
        import sglang as sgl

        @sgl.function  # type: ignore[misc]
        def image_qa(s: Any, images: Any, prompt: Any) -> None:
            for img in images:
                s += sgl.image(img)
            s += prompt
            s += sgl.gen("answer", max_tokens=max_new_tokens, temperature=temperature)

        state = image_qa.run(images=images, prompt=prompt, backend=self.model)
        return str(state["answer"])  # type: ignore[no-any-return]

    def _generate_with_vllm(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using vLLM engine."""
        from vllm import SamplingParams

        sampling_params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature)
        outputs = self.model.generate(  # type: ignore[attr-defined]
            {"prompt": prompt, "multi_modal_data": {"image": images}},
            sampling_params=sampling_params,
        )
        return outputs[0].outputs[0].text  # type: ignore[no-any-return]

    def _generate_with_transformers(
        self,
        images: list[Image.Image],
        prompt: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate using HuggingFace Transformers."""
        if self.processor is None or self.tokenizer is None:
            raise RuntimeError("Processor and tokenizer not initialized")

        # Qwen's vision-input helper lives in the separate qwen_vl_utils package,
        # not on the processor object.
        from qwen_vl_utils import process_vision_info

        messages = [
            {
                "role": "user",
                "content": [
                    *[{"type": "image", "image": img} for img in images],
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


def create_vlm_loader(model_name: str, config: VLMConfig) -> VLMLoader:
    """Factory function to create appropriate VLM loader based on model name.

    Parameters
    ----------
    model_name : str
        Name of the model to load. Supported values:

        - "llama-4-maverick" or "llama4-maverick"
        - "gemma-3-27b" or "gemma3"
        - "internvl3-78b" or "internvl3"
        - "pixtral-large" or "pixtral"
        - "qwen2.5-vl-72b" or "qwen25vl"
    config : VLMConfig
        Configuration for model loading and inference.

    Returns
    -------
    VLMLoader
        Appropriate loader instance for the specified model.

    Raises
    ------
    ValueError
        If model_name is not recognized.
    """
    model_name_lower = model_name.lower().replace("_", "-")

    if "llama-4-maverick" in model_name_lower or "llama4-maverick" in model_name_lower:
        return Llama4MaverickLoader(config)
    if "gemma-3" in model_name_lower or "gemma3" in model_name_lower:
        return Gemma3Loader(config)
    if "internvl3" in model_name_lower:
        return InternVL3Loader(config)
    if "pixtral" in model_name_lower:
        return PixtralLargeLoader(config)
    if "qwen2.5-vl" in model_name_lower or "qwen25vl" in model_name_lower:
        return Qwen25VLLoader(config)

    raise ValueError(
        f"Unknown model name: {model_name}. Supported models: "
        "llama-4-maverick, gemma-3-27b, internvl3-78b, pixtral-large, qwen2.5-vl-72b"
    )


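# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): end-to-end flow through the factory,
# loader lifecycle, and generation. The model id, image path, and memory
# budget below are assumptions chosen for demonstration, not values this
# module ships with or has validated.
if __name__ == "__main__":
    example_config = VLMConfig(
        model_id="Qwen/Qwen2.5-VL-72B-Instruct",  # assumed model id
        quantization=QuantizationType.FOUR_BIT,
        framework=InferenceFramework.TRANSFORMERS,
        max_memory_gb=80,
    )
    loader = create_vlm_loader("qwen2.5-vl-72b", example_config)
    loader.load()
    try:
        image = Image.open("example_page.png")  # hypothetical input image
        answer = loader.generate(
            images=[image],
            prompt="Summarize the contents of this document page.",
            max_new_tokens=256,
            temperature=0.2,
        )
        print(answer)
    finally:
        # Free GPU memory regardless of whether generation succeeded.
        loader.unload()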