"""Tensor mapping and URL parsing utilities. Provides utilities for mapping tensor names between different formats, parsing model URLs, and handling architecture-specific conversions. Uses UK English spelling conventions throughout. """ from __future__ import annotations import re from typing import ClassVar from helpers.models.quantisation import ModelSource, URLType class TensorMapper: """Maps tensor names between HuggingFace and GGUF conventions. Provides flexible tensor name translation supporting direct mappings, layer-aware transformations, and architecture-specific overrides. Handles both simple renames and complex pattern-based conversions. """ # Common direct mappings across architectures DIRECT_MAPPINGS: ClassVar[dict[str, str]] = { "model.embed_tokens.weight": "token_embd.weight", "model.norm.weight": "output_norm.weight", "lm_head.weight": "output.weight", } # Layer component patterns for transformer blocks LAYER_PATTERNS: ClassVar[dict[str, str]] = { "self_attn.q_proj.weight": "attn_q.weight", "self_attn.q_proj.bias": "attn_q.bias", "self_attn.k_proj.weight": "attn_k.weight", "self_attn.k_proj.bias": "attn_k.bias", "self_attn.v_proj.weight": "attn_v.weight", "self_attn.v_proj.bias": "attn_v.bias", "self_attn.o_proj": "attn_output.weight", "mlp.gate_proj": "ffn_gate.weight", "mlp.up_proj": "ffn_up.weight", "mlp.down_proj": "ffn_down.weight", "input_layernorm": "attn_norm.weight", "post_attention_layernorm": "ffn_norm.weight", } @classmethod def map_tensor_name(cls, original_name: str) -> str | None: """Map original tensor name to GGUF format. Translates HuggingFace tensor naming to GGUF format, handling embeddings, attention layers, feed-forward networks, and normalisation layers. Uses layer-aware mapping for transformer blocks whilst maintaining consistency across different model architectures. Returns: GGUF tensor name, or None if unmappable. """ # Check direct mappings first if original_name in cls.DIRECT_MAPPINGS: return cls.DIRECT_MAPPINGS[original_name] # Handle layer-specific tensors if ".layers." in original_name: return cls._map_layer_tensor(original_name) # Return None for unmapped tensors return None @classmethod def _map_layer_tensor(cls, tensor_name: str) -> str | None: """Map layer-specific tensor names. Handles tensors within transformer layers, extracting layer indices and mapping component names to GGUF conventions. Supports attention projections, feed-forward networks, and normalisation layers. Returns: Mapped GGUF tensor name with layer index, or None if unmappable. """ # Extract layer number parts = tensor_name.split(".") layer_idx = None for i, part in enumerate(parts): if part == "layers" and i + 1 < len(parts): layer_idx = parts[i + 1] break if layer_idx is None: return None # Check each pattern for pattern, replacement in cls.LAYER_PATTERNS.items(): if pattern in tensor_name: return f"blk.{layer_idx}.{replacement}" return None class URLParser: """Parses and validates model URLs from various sources. Handles HuggingFace URLs, Ollama-style GGUF references, and other model source formats. Extracts metadata including author, model name, and file patterns for appropriate download strategies. """ @staticmethod def parse(url: str) -> ModelSource: """Parse URL and extract model source information. Analyses URL format to determine source type and extract relevant metadata for model download and processing. Supports both standard HuggingFace URLs and Ollama-style GGUF repository references. Returns: ModelSource with parsed metadata and appropriate source type. Raises: ValueError: If URL format is not recognised or supported. """ if not url: msg = "URL cannot be empty" raise ValueError(msg) # Try Ollama-style GGUF URL first (hf.co/author/model:pattern) ollama_match = re.match(r"^hf\.co/([^:]+):(.+)$", url) if ollama_match: source_model = ollama_match.group(1) gguf_pattern = ollama_match.group(2) return URLParser._create_model_source( url, URLType.OLLAMA_GGUF, source_model, gguf_file_pattern=gguf_pattern, is_gguf_repo=True, ) # Try regular HuggingFace URL hf_match = re.match(r"https://huggingface\.co/([^/]+/[^/?]+)", url) if hf_match: source_model = hf_match.group(1) return URLParser._create_model_source( url, URLType.HUGGINGFACE, source_model, is_gguf_repo=False ) msg = ( "Invalid URL format\n" "Supported formats:\n" " - https://huggingface.co/username/model-name\n" " - hf.co/username/model-name-GGUF:F16" ) raise ValueError(msg) @staticmethod def _create_model_source( url: str, url_type: URLType, source_model: str, gguf_file_pattern: str | None = None, is_gguf_repo: bool = False, ) -> ModelSource: """Create ModelSource with parsed information. Constructs a ModelSource instance with extracted metadata, handling author/model name splitting and GGUF suffix removal for repository names. Ensures consistent naming conventions across different source types. Returns: Configured ModelSource instance with normalised metadata. """ author, model_name = source_model.split("/", 1) # Strip -GGUF suffix for GGUF repos if is_gguf_repo and model_name.endswith("-GGUF"): model_name = model_name[:-5] return ModelSource( url=url, url_type=url_type, source_model=source_model, original_author=author, model_name=model_name, gguf_file_pattern=gguf_file_pattern, is_gguf_repo=is_gguf_repo, )