Initial commit
This commit is contained in:
commit
ef7df1a8c3
28 changed files with 6829 additions and 0 deletions
196
helpers/utils/tensor_mapping.py
Normal file
196
helpers/utils/tensor_mapping.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
"""Tensor mapping and URL parsing utilities.
|
||||
|
||||
Provides utilities for mapping tensor names between different formats,
|
||||
parsing model URLs, and handling architecture-specific conversions.
|
||||
Uses UK English spelling conventions throughout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from helpers.models.quantisation import ModelSource, URLType
|
||||
|
||||
|
||||
class TensorMapper:
|
||||
"""Maps tensor names between HuggingFace and GGUF conventions.
|
||||
|
||||
Provides flexible tensor name translation supporting direct mappings,
|
||||
layer-aware transformations, and architecture-specific overrides.
|
||||
Handles both simple renames and complex pattern-based conversions.
|
||||
"""
|
||||
|
||||
# Common direct mappings across architectures
|
||||
DIRECT_MAPPINGS: ClassVar[dict[str, str]] = {
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
|
||||
# Layer component patterns for transformer blocks
|
||||
LAYER_PATTERNS: ClassVar[dict[str, str]] = {
|
||||
"self_attn.q_proj.weight": "attn_q.weight",
|
||||
"self_attn.q_proj.bias": "attn_q.bias",
|
||||
"self_attn.k_proj.weight": "attn_k.weight",
|
||||
"self_attn.k_proj.bias": "attn_k.bias",
|
||||
"self_attn.v_proj.weight": "attn_v.weight",
|
||||
"self_attn.v_proj.bias": "attn_v.bias",
|
||||
"self_attn.o_proj": "attn_output.weight",
|
||||
"mlp.gate_proj": "ffn_gate.weight",
|
||||
"mlp.up_proj": "ffn_up.weight",
|
||||
"mlp.down_proj": "ffn_down.weight",
|
||||
"input_layernorm": "attn_norm.weight",
|
||||
"post_attention_layernorm": "ffn_norm.weight",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def map_tensor_name(cls, original_name: str) -> str | None:
|
||||
"""Map original tensor name to GGUF format.
|
||||
|
||||
Translates HuggingFace tensor naming to GGUF format, handling embeddings,
|
||||
attention layers, feed-forward networks, and normalisation layers. Uses
|
||||
layer-aware mapping for transformer blocks whilst maintaining consistency
|
||||
across different model architectures.
|
||||
|
||||
Returns:
|
||||
GGUF tensor name, or None if unmappable.
|
||||
"""
|
||||
# Check direct mappings first
|
||||
if original_name in cls.DIRECT_MAPPINGS:
|
||||
return cls.DIRECT_MAPPINGS[original_name]
|
||||
|
||||
# Handle layer-specific tensors
|
||||
if ".layers." in original_name:
|
||||
return cls._map_layer_tensor(original_name)
|
||||
|
||||
# Return None for unmapped tensors
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _map_layer_tensor(cls, tensor_name: str) -> str | None:
|
||||
"""Map layer-specific tensor names.
|
||||
|
||||
Handles tensors within transformer layers, extracting layer indices
|
||||
and mapping component names to GGUF conventions.
|
||||
|
||||
Args:
|
||||
tensor_name: Layer tensor name containing .layers.N. pattern.
|
||||
|
||||
Returns:
|
||||
Mapped GGUF tensor name, or None if unmappable.
|
||||
"""
|
||||
# Extract layer number
|
||||
parts = tensor_name.split(".")
|
||||
layer_idx = None
|
||||
for i, part in enumerate(parts):
|
||||
if part == "layers" and i + 1 < len(parts):
|
||||
layer_idx = parts[i + 1]
|
||||
break
|
||||
|
||||
if layer_idx is None:
|
||||
return None
|
||||
|
||||
# Check each pattern
|
||||
for pattern, replacement in cls.LAYER_PATTERNS.items():
|
||||
if pattern in tensor_name:
|
||||
return f"blk.{layer_idx}.{replacement}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class URLParser:
|
||||
"""Parses and validates model URLs from various sources.
|
||||
|
||||
Handles HuggingFace URLs, Ollama-style GGUF references, and other
|
||||
model source formats. Extracts metadata including author, model name,
|
||||
and file patterns for appropriate download strategies.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def parse(url: str) -> ModelSource:
|
||||
"""Parse URL and extract model source information.
|
||||
|
||||
Analyses URL format to determine source type and extract relevant
|
||||
metadata for model download and processing.
|
||||
|
||||
Args:
|
||||
url: Model URL in supported format.
|
||||
|
||||
Returns:
|
||||
ModelSource with parsed information.
|
||||
|
||||
Raises:
|
||||
ValueError: If URL format is not recognised.
|
||||
"""
|
||||
if not url:
|
||||
msg = "URL cannot be empty"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Try Ollama-style GGUF URL first (hf.co/author/model:pattern)
|
||||
ollama_match = re.match(r"^hf\.co/([^:]+):(.+)$", url)
|
||||
if ollama_match:
|
||||
source_model = ollama_match.group(1)
|
||||
gguf_pattern = ollama_match.group(2)
|
||||
return URLParser._create_model_source(
|
||||
url,
|
||||
URLType.OLLAMA_GGUF,
|
||||
source_model,
|
||||
gguf_file_pattern=gguf_pattern,
|
||||
is_gguf_repo=True,
|
||||
)
|
||||
|
||||
# Try regular HuggingFace URL
|
||||
hf_match = re.match(r"https://huggingface\.co/([^/]+/[^/?]+)", url)
|
||||
if hf_match:
|
||||
source_model = hf_match.group(1)
|
||||
return URLParser._create_model_source(
|
||||
url, URLType.HUGGINGFACE, source_model, is_gguf_repo=False
|
||||
)
|
||||
|
||||
msg = (
|
||||
"Invalid URL format\n"
|
||||
"Supported formats:\n"
|
||||
" - https://huggingface.co/username/model-name\n"
|
||||
" - hf.co/username/model-name-GGUF:F16"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@staticmethod
|
||||
def _create_model_source(
|
||||
url: str,
|
||||
url_type: URLType,
|
||||
source_model: str,
|
||||
gguf_file_pattern: str | None = None,
|
||||
is_gguf_repo: bool = False,
|
||||
) -> ModelSource:
|
||||
"""Create ModelSource with parsed information.
|
||||
|
||||
Constructs a ModelSource instance with extracted metadata,
|
||||
handling author/model name splitting and GGUF suffix removal.
|
||||
|
||||
Args:
|
||||
url: Original URL.
|
||||
url_type: Type of URL (HuggingFace or Ollama GGUF).
|
||||
source_model: Repository identifier (author/model).
|
||||
gguf_file_pattern: Optional GGUF file pattern.
|
||||
is_gguf_repo: Whether this is a GGUF repository.
|
||||
|
||||
Returns:
|
||||
Configured ModelSource instance.
|
||||
"""
|
||||
author, model_name = source_model.split("/", 1)
|
||||
|
||||
# Strip -GGUF suffix for GGUF repos
|
||||
if is_gguf_repo and model_name.endswith("-GGUF"):
|
||||
model_name = model_name[:-5]
|
||||
|
||||
return ModelSource(
|
||||
url=url,
|
||||
url_type=url_type,
|
||||
source_model=source_model,
|
||||
original_author=author,
|
||||
model_name=model_name,
|
||||
gguf_file_pattern=gguf_file_pattern,
|
||||
is_gguf_repo=is_gguf_repo,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue