# llm-gguf-tools/helpers/utils/tensor_mapping.py

"""Tensor mapping and URL parsing utilities.
Provides utilities for mapping tensor names between different formats,
parsing model URLs, and handling architecture-specific conversions.
Uses UK English spelling conventions throughout.
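
Example (illustrative; assumes ModelSource exposes parsed fields as
attributes):
    >>> TensorMapper.map_tensor_name("model.norm.weight")
    'output_norm.weight'
    >>> URLParser.parse("https://huggingface.co/author/model").source_model
    'author/model'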
"""
from __future__ import annotations
import re
from typing import ClassVar
from helpers.models.quantisation import ModelSource, URLType
class TensorMapper:
"""Maps tensor names between HuggingFace and GGUF conventions.
Provides flexible tensor name translation supporting direct mappings,
layer-aware transformations, and architecture-specific overrides.
Handles both simple renames and complex pattern-based conversions.
"""
# Common direct mappings across architectures
DIRECT_MAPPINGS: ClassVar[dict[str, str]] = {
"model.embed_tokens.weight": "token_embd.weight",
"model.norm.weight": "output_norm.weight",
"lm_head.weight": "output.weight",
    }

    # Layer component patterns for transformer blocks
    LAYER_PATTERNS: ClassVar[dict[str, str]] = {
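        # Keys are matched as substrings of the full tensor name, so each one
        # carries an explicit ".weight"/".bias" suffix to avoid mis-mapping
        # bias tensors; the layer index is reinserted by _map_layer_tensor,
        # e.g. "model.layers.12.mlp.down_proj.weight" -> "blk.12.ffn_down.weight".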
"self_attn.q_proj.weight": "attn_q.weight",
"self_attn.q_proj.bias": "attn_q.bias",
"self_attn.k_proj.weight": "attn_k.weight",
"self_attn.k_proj.bias": "attn_k.bias",
"self_attn.v_proj.weight": "attn_v.weight",
"self_attn.v_proj.bias": "attn_v.bias",
"self_attn.o_proj": "attn_output.weight",
"mlp.gate_proj": "ffn_gate.weight",
"mlp.up_proj": "ffn_up.weight",
"mlp.down_proj": "ffn_down.weight",
"input_layernorm": "attn_norm.weight",
"post_attention_layernorm": "ffn_norm.weight",
    }

    @classmethod
def map_tensor_name(cls, original_name: str) -> str | None:
"""Map original tensor name to GGUF format.
Translates HuggingFace tensor naming to GGUF format, handling embeddings,
attention layers, feed-forward networks, and normalisation layers. Uses
layer-aware mapping for transformer blocks whilst maintaining consistency
across different model architectures.
Returns:
GGUF tensor name, or None if unmappable.
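
        Example:
            >>> TensorMapper.map_tensor_name("model.embed_tokens.weight")
            'token_embd.weight'
            >>> TensorMapper.map_tensor_name("model.layers.0.self_attn.q_proj.weight")
            'blk.0.attn_q.weight'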
"""
# Check direct mappings first
if original_name in cls.DIRECT_MAPPINGS:
return cls.DIRECT_MAPPINGS[original_name]
# Handle layer-specific tensors
if ".layers." in original_name:
return cls._map_layer_tensor(original_name)
# Return None for unmapped tensors
        return None

    @classmethod
def _map_layer_tensor(cls, tensor_name: str) -> str | None:
"""Map layer-specific tensor names.
Handles tensors within transformer layers, extracting layer indices
and mapping component names to GGUF conventions. Supports attention
projections, feed-forward networks, and normalisation layers.
Returns:
Mapped GGUF tensor name with layer index, or None if unmappable.
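
        Example:
            >>> TensorMapper._map_layer_tensor("model.layers.3.mlp.up_proj.weight")
            'blk.3.ffn_up.weight'
            >>> TensorMapper._map_layer_tensor("model.layers.3.rotary_emb.inv_freq") is None
            True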
"""
# Extract layer number
parts = tensor_name.split(".")
layer_idx = None
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts):
layer_idx = parts[i + 1]
break
if layer_idx is None:
return None
# Check each pattern
for pattern, replacement in cls.LAYER_PATTERNS.items():
if pattern in tensor_name:
return f"blk.{layer_idx}.{replacement}"
        return None


class URLParser:
"""Parses and validates model URLs from various sources.
Handles HuggingFace URLs, Ollama-style GGUF references, and other
model source formats. Extracts metadata including author, model name,
and file patterns for appropriate download strategies.
"""
@staticmethod
def parse(url: str) -> ModelSource:
"""Parse URL and extract model source information.
Analyses URL format to determine source type and extract relevant
metadata for model download and processing. Supports both standard
HuggingFace URLs and Ollama-style GGUF repository references.
Returns:
ModelSource with parsed metadata and appropriate source type.
Raises:
ValueError: If URL format is not recognised or supported.
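
        Example (illustrative; assumes ModelSource exposes its fields as
        attributes):
            >>> src = URLParser.parse("hf.co/author/Model-GGUF:Q4_K_M")
            >>> src.model_name
            'Model'
            >>> src.gguf_file_pattern
            'Q4_K_M'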
"""
if not url:
msg = "URL cannot be empty"
raise ValueError(msg)
# Try Ollama-style GGUF URL first (hf.co/author/model:pattern)
ollama_match = re.match(r"^hf\.co/([^:]+):(.+)$", url)
if ollama_match:
source_model = ollama_match.group(1)
gguf_pattern = ollama_match.group(2)
return URLParser._create_model_source(
url,
URLType.OLLAMA_GGUF,
source_model,
gguf_file_pattern=gguf_pattern,
is_gguf_repo=True,
)
# Try regular HuggingFace URL
hf_match = re.match(r"https://huggingface\.co/([^/]+/[^/?]+)", url)
if hf_match:
source_model = hf_match.group(1)
return URLParser._create_model_source(
url, URLType.HUGGINGFACE, source_model, is_gguf_repo=False
)
msg = (
"Invalid URL format\n"
"Supported formats:\n"
" - https://huggingface.co/username/model-name\n"
" - hf.co/username/model-name-GGUF:F16"
)
        raise ValueError(msg)

    @staticmethod
def _create_model_source(
url: str,
url_type: URLType,
source_model: str,
gguf_file_pattern: str | None = None,
is_gguf_repo: bool = False,
) -> ModelSource:
"""Create ModelSource with parsed information.
Constructs a ModelSource instance with extracted metadata, handling
author/model name splitting and GGUF suffix removal for repository names.
Ensures consistent naming conventions across different source types.
Returns:
Configured ModelSource instance with normalised metadata.
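
        Example (illustrative; assumes ModelSource exposes its fields as
        attributes):
            >>> URLParser._create_model_source(
            ...     "https://huggingface.co/author/model",
            ...     URLType.HUGGINGFACE,
            ...     "author/model",
            ... ).original_author
            'author'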
"""
author, model_name = source_model.split("/", 1)
# Strip -GGUF suffix for GGUF repos
if is_gguf_repo and model_name.endswith("-GGUF"):
model_name = model_name[:-5]
return ModelSource(
url=url,
url_type=url_type,
source_model=source_model,
original_author=author,
model_name=model_name,
gguf_file_pattern=gguf_file_pattern,
is_gguf_repo=is_gguf_repo,
)