"""GGUF file operations service.
Provides unified interface for creating, writing, and manipulating GGUF files.
Consolidates GGUF-specific operations from conversion and quantisation workflows.
Uses UK English spelling conventions throughout.
"""

from __future__ import annotations

import gc
import json
import traceback
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol

import gguf
import torch
from safetensors import safe_open

from helpers.logger import logger
from helpers.services.filesystem import FilesystemService
from helpers.utils.config_parser import ConfigParser


class VisionConfig(Protocol):
    """Protocol for vision model configuration."""

    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    patch_size: int
    spatial_merge_size: int


class TensorMapper(Protocol):
    """Protocol for tensor name mapping."""

    def map_tensor_name(self, name: str) -> str | None:
        """Map a tensor name to its GGUF equivalent."""
        ...


if TYPE_CHECKING:
    import numpy as np

    from helpers.models.conversion import ModelConfig


class GGUFWriter:
    """Manages GGUF file creation and metadata writing.

    Provides high-level interface for GGUF file operations including metadata
    configuration, tensor addition, and tokeniser integration. Encapsulates
    low-level GGUF library interactions for consistent error handling.
    """

    def __init__(self, output_path: Path, architecture: str) -> None:
        """Initialise GGUF writer with output path and architecture.

        Creates the underlying GGUF writer instance and prepares for metadata
        and tensor addition. Sets up the file structure for the specified
        model architecture.
        """
self.output_path = output_path
self.architecture = architecture
self.writer = gguf.GGUFWriter(str(output_path), architecture)
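        # The underlying gguf-py writer defers all file output until the
        # write_*_to_file calls in finalise(); the add_* methods below only
        # stage data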
logger.info(f"Created GGUF writer for {architecture} architecture")

    def add_metadata(self, model_config: ModelConfig, model_name: str) -> None:
        """Add comprehensive metadata from model configuration.

        Writes general model information, architectural parameters, and
        quantisation settings to the GGUF file header. Handles both standard
        and vision model configurations with appropriate parameter mapping.
        """
# General metadata
self.writer.add_name(model_name)
self.writer.add_description(f"Converted from {model_config.architectures[0]}")
self.writer.add_file_type(gguf.LlamaFileType.ALL_F32)
# Log architecture being used
logger.info(f"Setting GGUF architecture: {self.architecture}")
if self.architecture not in {"llama", "qwen2", "gemma", "phi3", "falcon", "gpt2"}:
logger.warning(f"Architecture '{self.architecture}' may not be supported by llama.cpp")
# Model parameters from config
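        # These writer calls populate architecture-scoped GGUF keys such as
        # "llama.context_length" and "llama.embedding_length" (key naming
        # follows the gguf-py writer; noted here for orientation)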
params = model_config.to_gguf_params()
self.writer.add_context_length(params.context_length)
self.writer.add_embedding_length(params.embedding_length)
self.writer.add_block_count(params.block_count)
self.writer.add_feed_forward_length(params.feed_forward_length)
self.writer.add_head_count(params.attention_head_count)
self.writer.add_head_count_kv(params.attention_head_count_kv)
self.writer.add_layer_norm_rms_eps(params.attention_layer_norm_rms_epsilon)
self.writer.add_rope_freq_base(params.rope_freq_base)
self.writer.add_rope_dimension_count(params.rope_dimension_count)
logger.info(f"Added metadata: {params.block_count} layers, {params.context_length} context")

    def add_vision_metadata(self, vision_config: VisionConfig | None) -> None:
        """Add vision model parameters to GGUF metadata.

        Configures vision-specific parameters for multimodal models including
        embedding dimensions, attention heads, and spatial processing settings.
        """
if not vision_config:
return
logger.info("Adding vision model parameters...")
self.writer.add_vision_embedding_length(vision_config.hidden_size)
self.writer.add_vision_block_count(vision_config.num_hidden_layers)
self.writer.add_vision_head_count(vision_config.num_attention_heads)
self.writer.add_vision_feed_forward_length(vision_config.intermediate_size)
self.writer.add_vision_patch_size(vision_config.patch_size)
self.writer.add_vision_spatial_merge_size(vision_config.spatial_merge_size)
if hasattr(vision_config, "rms_norm_eps") and vision_config.rms_norm_eps:
self.writer.add_vision_attention_layernorm_eps(vision_config.rms_norm_eps)

    def add_tokeniser(self, tokeniser_config: dict[str, Any]) -> None:
        """Add tokeniser metadata to GGUF file.

        Writes special token IDs and tokeniser model type to enable proper
        text processing during inference. Uses sensible defaults for missing
        configuration values.
        """
self.writer.add_bos_token_id(tokeniser_config.get("bos_token_id", 1))
self.writer.add_eos_token_id(tokeniser_config.get("eos_token_id", 2))
self.writer.add_unk_token_id(tokeniser_config.get("unk_token_id", 0))
self.writer.add_pad_token_id(tokeniser_config.get("pad_token_id", 0))
# Add BOS/EOS token addition flags if available
if "add_bos_token" in tokeniser_config:
self.writer.add_add_bos_token(tokeniser_config["add_bos_token"])
if "add_eos_token" in tokeniser_config:
self.writer.add_add_eos_token(tokeniser_config["add_eos_token"])
# Note: tokenizer_model is set by add_tokeniser_vocabulary based on actual tokenizer type
logger.info("Added tokeniser configuration")

    def add_tokeniser_vocabulary(self, model_path: Path) -> None:
        """Add full tokeniser vocabulary to GGUF file.

        Loads and embeds the complete tokeniser vocabulary including tokens,
        merges, and scores to enable standalone model usage without external
        tokeniser files. Supports BPE, Unigram, and WordPiece tokenisers.
        """
tokenizer_path = model_path / "tokenizer.json"
if not tokenizer_path.exists():
logger.warning("tokenizer.json not found, skipping vocabulary embedding")
return
try:
            with tokenizer_path.open(encoding="utf-8") as f:
tokenizer_data = json.load(f)
model_data = tokenizer_data.get("model", {})
model_type = model_data.get("type", "")
# Get pre-tokenizer information
pre_tokenizer = tokenizer_data.get("pre_tokenizer", {})
pre_tokenizer_type = self._get_pre_tokenizer_type(pre_tokenizer)
# Get added tokens
added_tokens = tokenizer_data.get("added_tokens", [])
if model_type == "BPE":
self._add_bpe_tokenizer(model_data, added_tokens, pre_tokenizer_type)
elif model_type == "Unigram":
self._add_unigram_tokenizer(model_data, added_tokens)
elif model_type == "WordPiece":
self._add_wordpiece_tokenizer(model_data, added_tokens)
else:
logger.warning(f"Unsupported tokenizer type: {model_type}")
# Try to add as generic tokenizer
self._add_generic_tokenizer(model_data, tokenizer_data)
except Exception as e:
logger.error(f"Failed to load tokeniser vocabulary: {e}")
logger.error(traceback.format_exc())

    def _get_pre_tokenizer_type(self, pre_tokenizer: dict[str, Any]) -> str:
        """Determine pre-tokenizer type from configuration.

        Returns:
            GGUF pre-tokeniser identifier: "llama3" for ByteLevel
            pre-tokenisation, otherwise "default".
        """
if not pre_tokenizer:
return "default"
# Check for various pre-tokenizer types
pre_type = pre_tokenizer.get("type", "")
if "ByteLevel" in str(pre_type):
return "llama3"
if "Metaspace" in str(pre_type):
return "default"
return "default"

    def _add_bpe_tokenizer(
        self, model_data: dict[str, Any], added_tokens: list[dict[str, Any]], pre_type: str
    ) -> None:
        """Add BPE tokenizer vocabulary to GGUF."""
vocab = model_data.get("vocab", {})
merges = model_data.get("merges", [])
if not vocab:
logger.warning("No vocabulary found in BPE tokenizer")
return
# Create token list sorted by index
max_idx = max(vocab.values()) if vocab else 0
tokens = [""] * (max_idx + 1)
for token, idx in vocab.items():
if 0 <= idx < len(tokens):
tokens[idx] = token
# Handle added tokens
for added_token in added_tokens:
token_id = added_token.get("id")
content = added_token.get("content")
if token_id is not None and content is not None:
if token_id >= len(tokens):
tokens.extend([""] * (token_id - len(tokens) + 1))
tokens[token_id] = content
        # Prepare token types: added tokens flagged "special" become CONTROL,
        # all others NORMAL
        special_ids = {
            added_token["id"]
            for added_token in added_tokens
            if added_token.get("special", False) and added_token.get("id") is not None
        }
        token_types = [
            gguf.TokenType.CONTROL if i in special_ids else gguf.TokenType.NORMAL
            for i in range(len(tokens))
        ]
# Add to GGUF
self.writer.add_tokenizer_model("gpt2")
self.writer.add_tokenizer_pre(pre_type)
self.writer.add_token_list(tokens)
self.writer.add_token_scores([0.0] * len(tokens))
self.writer.add_token_types(token_types)
if merges:
self.writer.add_token_merges(merges)
logger.info(f"Added {len(merges)} BPE merges")
logger.info(f"Successfully embedded BPE tokeniser ({len(tokens)} tokens)")

    def _add_unigram_tokenizer(
        self,
        model_data: dict[str, Any],
        added_tokens: list[dict[str, Any]],  # noqa: ARG002
    ) -> None:
        """Add Unigram/SentencePiece tokenizer to GGUF."""
vocab = model_data.get("vocab", [])
if not vocab:
logger.warning("No vocabulary found in Unigram tokenizer")
return
tokens = []
scores = []
token_types = []
# Process regular vocabulary
        for item in vocab:
            if isinstance(item, list) and len(item) >= 2:
                token = item[0]
                score = float(item[1])
                tokens.append(token)
                scores.append(score)
                # Determine token type; byte tokens ("<0xNN>", exactly six
                # characters) must be checked before the generic angle-bracket
                # control pattern, which would otherwise match them first
                if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
                    token_types.append(gguf.TokenType.BYTE)
                elif token.startswith("<") and token.endswith(">"):
                    token_types.append(gguf.TokenType.CONTROL)
                else:
                    token_types.append(gguf.TokenType.NORMAL)
# Add to GGUF
self.writer.add_tokenizer_model("llama")
self.writer.add_tokenizer_pre("default")
self.writer.add_token_list(tokens)
self.writer.add_token_scores(scores)
self.writer.add_token_types(token_types)
logger.info(f"Successfully embedded Unigram tokeniser ({len(tokens)} tokens)")

    def _add_wordpiece_tokenizer(
        self,
        model_data: dict[str, Any],
        added_tokens: list[dict[str, Any]],  # noqa: ARG002
    ) -> None:
        """Add WordPiece tokenizer to GGUF."""
vocab = model_data.get("vocab", {})
if not vocab:
logger.warning("No vocabulary found in WordPiece tokenizer")
return
# Create token list sorted by index
max_idx = max(vocab.values()) if vocab else 0
tokens = [""] * (max_idx + 1)
for token, idx in vocab.items():
if 0 <= idx < len(tokens):
tokens[idx] = token
# Token types (all normal for WordPiece)
token_types = [gguf.TokenType.NORMAL] * len(tokens)
# Add to GGUF
self.writer.add_tokenizer_model("bert")
self.writer.add_tokenizer_pre("default")
self.writer.add_token_list(tokens)
self.writer.add_token_scores([0.0] * len(tokens))
self.writer.add_token_types(token_types)
logger.info(f"Successfully embedded WordPiece tokeniser ({len(tokens)} tokens)")

    def _add_generic_tokenizer(
        self,
        model_data: dict[str, Any],
        tokenizer_data: dict[str, Any],  # noqa: ARG002
    ) -> None:
        """Try to add a generic tokenizer based on available data."""
vocab = model_data.get("vocab")
if not vocab:
logger.warning("Cannot extract vocabulary from unknown tokenizer type")
return
# Try to extract tokens in a generic way
tokens = []
if isinstance(vocab, dict):
# Dictionary-style vocab
max_idx = max(vocab.values()) if vocab else 0
tokens = [""] * (max_idx + 1)
for token, idx in vocab.items():
if 0 <= idx < len(tokens):
tokens[idx] = token
elif isinstance(vocab, list):
# List-style vocab
for item in vocab:
if isinstance(item, str):
tokens.append(item)
elif isinstance(item, list) and len(item) > 0:
tokens.append(item[0])
if tokens:
self.writer.add_tokenizer_model("llama") # Default to llama
self.writer.add_tokenizer_pre("default")
self.writer.add_token_list(tokens)
self.writer.add_token_scores([0.0] * len(tokens))
self.writer.add_token_types([gguf.TokenType.NORMAL] * len(tokens))
logger.info(f"Added generic tokeniser ({len(tokens)} tokens)")
else:
logger.warning("Could not extract tokens from unknown tokenizer format")

    def add_tensor(self, name: str, data: np.ndarray) -> None:
        """Add a tensor to the GGUF file.

        Writes tensor data with the specified name, delegating data layout
        and dtype handling to the underlying gguf writer.
        """
self.writer.add_tensor(name, data)

    def finalise(self) -> None:
        """Write all data to file and close writer.

        Completes the GGUF file creation by writing headers, key-value data,
        and tensor data in the correct order. Ensures proper file closure.
        """
logger.info(f"Writing GGUF file to {self.output_path}")
self.writer.write_header_to_file()
self.writer.write_kv_data_to_file()
self.writer.write_tensors_to_file()
self.writer.close()
logger.info("GGUF file written successfully")


class GGUFConverter:
    """High-level GGUF conversion orchestrator.

    Coordinates the complete conversion workflow from source models to GGUF
    format, managing metadata extraction, tensor mapping, and file writing.
    """

    @staticmethod
    def convert_safetensors(
        model_path: Path,
        output_path: Path,
        model_config: ModelConfig,
        architecture: str,
        tensor_mapper: TensorMapper,
    ) -> bool:
        """Convert SafeTensors model to GGUF format.

        Orchestrates the conversion process including metadata setup, tensor
        loading with BFloat16 support, name mapping, and tokeniser integration.

        Returns:
            True once conversion completes; unrecoverable errors propagate as
            exceptions rather than a False return.
        """
logger.info(f"Converting {model_path.name} to GGUF...")
# Create writer
writer_wrapper = GGUFWriter(output_path, architecture)
# Add metadata
writer_wrapper.add_metadata(model_config, model_path.name)
# Add vision metadata if present
if model_config.vision_config:
writer_wrapper.add_vision_metadata(model_config.vision_config)
# Load and add tensors
fs = FilesystemService()
tensor_files = fs.find_safetensor_files(model_path)
logger.info(f"Found {len(tensor_files)} tensor file(s)")
tensor_count = 0
for tensor_file in tensor_files:
logger.info(f"Loading {tensor_file.name}...")
with safe_open(tensor_file, framework="pt") as f:
for tensor_name in f.keys(): # noqa: SIM118
tensor_data = f.get_tensor(tensor_name)
                    # NumPy has no native bfloat16 dtype, so upcast to float32
                    # before converting the torch tensor to a numpy array
                    if hasattr(tensor_data, "numpy"):
                        if tensor_data.dtype == torch.bfloat16:
                            tensor_data = tensor_data.float()
                        tensor_data = tensor_data.numpy()
# Map tensor name
gguf_name = tensor_mapper.map_tensor_name(tensor_name)
if gguf_name:
writer_wrapper.add_tensor(gguf_name, tensor_data)
tensor_count += 1
if tensor_count % 100 == 0:
logger.info(f" Processed {tensor_count} tensors...")
# Free memory after processing each tensor
del tensor_data
# Force garbage collection after processing each file
gc.collect()
logger.info(f"Total tensors processed: {tensor_count}")
# Add tokeniser configuration
try:
tok_config = ConfigParser.load_tokeniser_config(model_path)
writer_wrapper.add_tokeniser(tok_config)
logger.info("Tokeniser configuration added")
except Exception as e:
logger.warning(f"Could not add tokeniser configuration: {e}")
# Add tokeniser vocabulary (critical for standalone usage)
try:
writer_wrapper.add_tokeniser_vocabulary(model_path)
except Exception as e:
logger.error(f"Failed to embed tokeniser vocabulary: {e}")
logger.error("Model will not work without external tokeniser files!")
# Finalise file
writer_wrapper.finalise()
file_size = fs.get_file_size(output_path)
logger.info(f"Conversion complete! Output: {output_path} ({file_size})")
return True