llm-gguf-tools/helpers/utils/config_parser.py

"""Configuration parsing utilities.
Provides utilities for parsing model configurations, inferring parameters,
and handling architecture-specific settings. Uses UK English spelling
conventions throughout.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from helpers.filesystem import FilesystemService
from helpers.models.conversion import GGUFParameters, ModelConfig, VisionConfig
if TYPE_CHECKING:
from pathlib import Path
class ConfigParser:
"""Parses and transforms model configuration files.
Handles loading of HuggingFace config.json files, parameter inference,
and conversion to GGUF-compatible formats. Provides sensible defaults
for missing values and architecture-specific handling.
"""
def __init__(self) -> None:
"""Initialise ConfigParser."""
self.fs = FilesystemService()
def load_model_config(self, model_path: Path) -> ModelConfig:
"""Load model configuration from config.json file.
Reads the standard HuggingFace config.json file and parses it into
a structured ModelConfig instance with proper type validation. Handles
vision model configurations and provides sensible defaults for missing values.
Returns:
Parsed ModelConfig instance.
"""
config_file = model_path / "config.json"
raw_config = self.fs.load_json_config(config_file)
# Parse vision config if present
vision_config = None
if "vision_config" in raw_config:
vision_config = VisionConfig(**raw_config["vision_config"])
# Create ModelConfig with parsed values
return ModelConfig(
architectures=raw_config.get("architectures", ["Unknown"]),
model_type=raw_config.get("model_type", "unknown"),
vocab_size=raw_config.get("vocab_size", 32000),
max_position_embeddings=raw_config.get("max_position_embeddings", 2048),
hidden_size=raw_config.get("hidden_size", 4096),
num_hidden_layers=raw_config.get("num_hidden_layers", 32),
intermediate_size=raw_config.get("intermediate_size", 11008),
num_attention_heads=raw_config.get("num_attention_heads", 32),
num_key_value_heads=raw_config.get("num_key_value_heads"),
rope_theta=raw_config.get("rope_theta", 10000.0),
rope_scaling=raw_config.get("rope_scaling"),
rms_norm_eps=raw_config.get("rms_norm_eps", 1e-5),
vision_config=vision_config,
)
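
    # Illustrative usage (hypothetical path; assumes a standard HuggingFace
    # layout with config.json at the directory root):
    #
    #     parser = ConfigParser()
    #     config = parser.load_model_config(Path("./models/my-model"))
    #     print(config.model_type, config.hidden_size)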

    def infer_gguf_parameters(self, config: ModelConfig) -> GGUFParameters:
        """Infer GGUF parameters from model configuration.

        Translates HuggingFace model configuration to GGUF parameter format,
        providing sensible defaults for missing values and handling various
        architecture conventions. Calculates derived parameters like RoPE
        dimensions and handles grouped-query attention configurations.

        Returns:
            GGUFParameters with inferred values and proper type validation.
        """
        # Calculate derived parameters
        num_heads = config.num_attention_heads
        embedding_length = config.hidden_size
        rope_dimension_count = embedding_length // num_heads

        # Handle KV heads (for GQA models)
        num_kv_heads = config.num_key_value_heads or num_heads

        # Create GGUFParameters using dict with aliases
        params_dict = {
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "embedding_length": embedding_length,
            "block_count": config.num_hidden_layers,
            "feed_forward_length": config.intermediate_size,
            "attention.head_count": num_heads,
            "attention.head_count_kv": num_kv_heads,
            "attention.layer_norm_rms_epsilon": config.rms_norm_eps,
            "rope.freq_base": config.rope_theta,
            "rope.dimension_count": rope_dimension_count,
        }
        params = GGUFParameters.model_validate(params_dict)

        # Add RoPE scaling if present
        if config.rope_scaling:
            params.rope_scaling_type = config.rope_scaling.get("type", "linear")
            params.rope_scaling_factor = config.rope_scaling.get("factor", 1.0)

        return params
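
    # Worked example of the derived values above (figures illustrative,
    # matching a typical 7B LLaMA-style config): hidden_size=4096 with
    # num_attention_heads=32 gives rope.dimension_count = 4096 // 32 = 128.
    # A GQA model declaring num_key_value_heads=8 keeps head_count=32 but
    # sets head_count_kv=8; models omitting the field fall back to full
    # multi-head attention (head_count_kv == head_count).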

    @staticmethod
    def get_architecture_mapping(architecture: str) -> str:
        """Get the GGUF architecture name for a model.

        Returns the original architecture name to preserve model identity.
        Only maps architectures that are truly compatible.

        Returns:
            Architecture name for GGUF, preserving original when possible.
        """
        # Only map architectures that are ACTUALLY the same
        # DO NOT map incompatible architectures
        known_compatible = {
            "LlamaForCausalLM": "llama",
            "MistralForCausalLM": "llama",
            "Qwen2ForCausalLM": "qwen2",
            "GemmaForCausalLM": "gemma",
            "GptOssForCausalLM": "gptoss",
            "Phi3ForCausalLM": "phi3",
            "FalconForCausalLM": "falcon",
            "GPT2LMHeadModel": "gpt2",
            "GPTJForCausalLM": "gptj",
            "GPTNeoXForCausalLM": "gptneox",
            "MPTForCausalLM": "mpt",
            "BaichuanForCausalLM": "baichuan",
            "StableLMEpochForCausalLM": "stablelm",
        }
        if architecture in known_compatible:
            return known_compatible[architecture]

        # For unknown architectures, preserve the original name to make it
        # clear the model needs proper support; strip common suffixes to get
        # a cleaner architecture name
        arch_name = architecture
        for suffix in ["ForCausalLM", "LMHeadModel", "ForConditionalGeneration"]:
            if arch_name.endswith(suffix):
                arch_name = arch_name[: -len(suffix)]
                break
        arch_name = arch_name.lower()

        # Special case: convert "gpt-oss" to "gptoss"
        if arch_name == "gpt-oss":
            arch_name = "gptoss"

        return arch_name
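
    # Illustrative behaviour of the fallback path (class names hypothetical):
    #
    #     ConfigParser.get_architecture_mapping("MistralForCausalLM")  # "llama"
    #     ConfigParser.get_architecture_mapping("FooBarForCausalLM")   # "foobar"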

    @staticmethod
    def load_tokeniser_config(model_path: Path) -> dict[str, Any]:
        """Load tokeniser configuration from model directory.

        Reads tokenizer_config.json to extract special token IDs and other
        tokenisation parameters required for GGUF metadata. Provides sensible
        defaults when configuration files are missing or incomplete.

        Returns:
            Tokeniser configuration dictionary with token IDs and model type.
        """
        fs = FilesystemService()
        tokeniser_config_path = model_path / "tokenizer_config.json"

        if not tokeniser_config_path.exists():
            # Return defaults if no config found
            return {
                "bos_token_id": 1,
                "eos_token_id": 2,
                "unk_token_id": 0,
                "pad_token_id": 0,
            }

        config = fs.load_json_config(tokeniser_config_path)

        # Try to find special token IDs from added_tokens_decoder
        added_tokens = config.get("added_tokens_decoder", {})
        eos_token_id = config.get("eos_token_id")
        bos_token_id = config.get("bos_token_id")

        # If not directly specified, search in added_tokens_decoder
        if eos_token_id is None:
            for token_id, token_info in added_tokens.items():
                if token_info.get("content") == "<|endoftext|>":
                    eos_token_id = int(token_id)
                    break
        if bos_token_id is None:
            for token_id, token_info in added_tokens.items():
                if token_info.get("content") in {"<|im_start|>", "<s>", "<|startoftext|>"}:
                    bos_token_id = int(token_id)
                    break

        # Extract token IDs with better defaults
        return {
            "bos_token_id": bos_token_id if bos_token_id is not None else 1,
            "eos_token_id": eos_token_id if eos_token_id is not None else 2,
            "unk_token_id": config.get("unk_token_id", 0),
            "pad_token_id": config.get(
                "pad_token_id", eos_token_id if eos_token_id is not None else 0
            ),
            "model_type": config.get("model_type", "llama"),
            "add_bos_token": config.get("add_bos_token", True),
            "add_eos_token": config.get("add_eos_token", False),
        }
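

# Illustrative smoke test, not part of the library API. The model directory
# below is a hypothetical path; point it at a local HuggingFace checkout
# containing config.json (and optionally tokenizer_config.json) to exercise
# the parser end to end.
if __name__ == "__main__":
    from pathlib import Path

    model_dir = Path("./models/example-model")  # hypothetical path
    if (model_dir / "config.json").exists():
        parser = ConfigParser()
        model_config = parser.load_model_config(model_dir)
        gguf_params = parser.infer_gguf_parameters(model_config)
        arch = ConfigParser.get_architecture_mapping(model_config.architectures[0])
        tokeniser = ConfigParser.load_tokeniser_config(model_dir)
        print(f"GGUF architecture: {arch}")
        print(f"Inferred parameters: {gguf_params.model_dump(by_alias=True)}")
        print(f"BOS/EOS token IDs: {tokeniser['bos_token_id']}/{tokeniser['eos_token_id']}")
    else:
        print(f"No config.json found under {model_dir}; nothing to parse.")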