"""Tensor mapping and URL parsing utilities.
|
|
|
|
Provides utilities for mapping tensor names between different formats,
|
|
parsing model URLs, and handling architecture-specific conversions.
|
|
Uses UK English spelling conventions throughout.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import ClassVar
|
|
|
|
from helpers.models.quantisation import ModelSource, URLType
|
|
|
|
|
|
class TensorMapper:
    """Maps tensor names between HuggingFace and GGUF conventions.

    Provides flexible tensor name translation supporting direct mappings,
    layer-aware transformations, and architecture-specific overrides.
    Handles both simple renames and complex pattern-based conversions.
    """

    # Common direct mappings across architectures
    DIRECT_MAPPINGS: ClassVar[dict[str, str]] = {
        "model.embed_tokens.weight": "token_embd.weight",
        "model.norm.weight": "output_norm.weight",
        "lm_head.weight": "output.weight",
    }

    # Layer component patterns for transformer blocks
    LAYER_PATTERNS: ClassVar[dict[str, str]] = {
        "self_attn.q_proj.weight": "attn_q.weight",
        "self_attn.q_proj.bias": "attn_q.bias",
        "self_attn.k_proj.weight": "attn_k.weight",
        "self_attn.k_proj.bias": "attn_k.bias",
        "self_attn.v_proj.weight": "attn_v.weight",
        "self_attn.v_proj.bias": "attn_v.bias",
        "self_attn.o_proj": "attn_output.weight",
        "mlp.gate_proj": "ffn_gate.weight",
        "mlp.up_proj": "ffn_up.weight",
        "mlp.down_proj": "ffn_down.weight",
        "input_layernorm": "attn_norm.weight",
        "post_attention_layernorm": "ffn_norm.weight",
    }
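    # Illustrative examples only, derived from the two tables above, which
    # map_tensor_name() below applies:
    #   "model.embed_tokens.weight"               -> "token_embd.weight"
    #   "model.layers.0.self_attn.q_proj.weight"  -> "blk.0.attn_q.weight"
    #   "model.layers.12.mlp.down_proj.weight"    -> "blk.12.ffn_down.weight"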

    @classmethod
    def map_tensor_name(cls, original_name: str) -> str | None:
        """Map original tensor name to GGUF format.

        Translates HuggingFace tensor naming to GGUF format, handling embeddings,
        attention layers, feed-forward networks, and normalisation layers. Uses
        layer-aware mapping for transformer blocks whilst maintaining consistency
        across different model architectures.

        Returns:
            GGUF tensor name, or None if unmappable.
        """
        # Check direct mappings first
        if original_name in cls.DIRECT_MAPPINGS:
            return cls.DIRECT_MAPPINGS[original_name]

        # Handle layer-specific tensors
        if ".layers." in original_name:
            return cls._map_layer_tensor(original_name)

        # Return None for unmapped tensors
        return None

    @classmethod
    def _map_layer_tensor(cls, tensor_name: str) -> str | None:
        """Map layer-specific tensor names.

        Handles tensors within transformer layers, extracting layer indices
        and mapping component names to GGUF conventions.

        Args:
            tensor_name: Layer tensor name containing .layers.N. pattern.

        Returns:
            Mapped GGUF tensor name, or None if unmappable.
        """
        # Extract layer number
        parts = tensor_name.split(".")
        layer_idx = None
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts):
                layer_idx = parts[i + 1]
                break

        if layer_idx is None:
            return None

        # Check each pattern
        for pattern, replacement in cls.LAYER_PATTERNS.items():
            if pattern in tensor_name:
                return f"blk.{layer_idx}.{replacement}"

        return None


class URLParser:
    """Parses and validates model URLs from various sources.

    Handles HuggingFace URLs, Ollama-style GGUF references, and other
    model source formats. Extracts metadata including author, model name,
    and file patterns for appropriate download strategies.
    """

    @staticmethod
    def parse(url: str) -> ModelSource:
        """Parse URL and extract model source information.

        Analyses URL format to determine source type and extract relevant
        metadata for model download and processing.

        Args:
            url: Model URL in supported format.

        Returns:
            ModelSource with parsed information.

        Raises:
            ValueError: If URL format is not recognised.
        """
        if not url:
            msg = "URL cannot be empty"
            raise ValueError(msg)

        # Try Ollama-style GGUF URL first (hf.co/author/model:pattern)
        ollama_match = re.match(r"^hf\.co/([^:]+):(.+)$", url)
        if ollama_match:
            source_model = ollama_match.group(1)
            gguf_pattern = ollama_match.group(2)
            return URLParser._create_model_source(
                url,
                URLType.OLLAMA_GGUF,
                source_model,
                gguf_file_pattern=gguf_pattern,
                is_gguf_repo=True,
            )

        # Try regular HuggingFace URL
        hf_match = re.match(r"https://huggingface\.co/([^/]+/[^/?]+)", url)
        if hf_match:
            source_model = hf_match.group(1)
            return URLParser._create_model_source(
                url, URLType.HUGGINGFACE, source_model, is_gguf_repo=False
            )

        msg = (
            "Invalid URL format\n"
            "Supported formats:\n"
            " - https://huggingface.co/username/model-name\n"
            " - hf.co/username/model-name-GGUF:F16"
        )
        raise ValueError(msg)
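
    # Illustrative parses (hypothetical repository names):
    #   "https://huggingface.co/example-author/example-model"
    #       -> url_type=URLType.HUGGINGFACE,
    #          source_model="example-author/example-model"
    #   "hf.co/example-author/example-model-GGUF:Q4_K_M"
    #       -> url_type=URLType.OLLAMA_GGUF, gguf_file_pattern="Q4_K_M",
    #          model_name="example-model" (the "-GGUF" suffix is stripped)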

    @staticmethod
    def _create_model_source(
        url: str,
        url_type: URLType,
        source_model: str,
        gguf_file_pattern: str | None = None,
        is_gguf_repo: bool = False,
    ) -> ModelSource:
        """Create ModelSource with parsed information.

        Constructs a ModelSource instance with extracted metadata,
        handling author/model name splitting and GGUF suffix removal.

        Args:
            url: Original URL.
            url_type: Type of URL (HuggingFace or Ollama GGUF).
            source_model: Repository identifier (author/model).
            gguf_file_pattern: Optional GGUF file pattern.
            is_gguf_repo: Whether this is a GGUF repository.

        Returns:
            Configured ModelSource instance.
        """
        author, model_name = source_model.split("/", 1)

        # Strip -GGUF suffix for GGUF repos
        if is_gguf_repo and model_name.endswith("-GGUF"):
            model_name = model_name[:-5]

        return ModelSource(
            url=url,
            url_type=url_type,
            source_model=source_model,
            original_author=author,
            model_name=model_name,
            gguf_file_pattern=gguf_file_pattern,
            is_gguf_repo=is_gguf_repo,
        )
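

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the library). Assumes that
# ModelSource exposes the fields it is constructed with (e.g. ``model_name``,
# ``gguf_file_pattern``) as attributes; repository names are hypothetical.
if __name__ == "__main__":
    # Direct and layer-aware tensor name translation.
    print(TensorMapper.map_tensor_name("model.embed_tokens.weight"))
    # -> token_embd.weight
    print(TensorMapper.map_tensor_name("model.layers.3.mlp.gate_proj.weight"))
    # -> blk.3.ffn_gate.weight

    # Both supported URL formats.
    hf_source = URLParser.parse("https://huggingface.co/example-author/example-model")
    print(hf_source.source_model, hf_source.is_gguf_repo)

    gguf_source = URLParser.parse("hf.co/example-author/example-model-GGUF:Q4_K_M")
    print(gguf_source.model_name, gguf_source.gguf_file_pattern)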