"""Configuration parsing utilities.
|
|
|
|
Provides utilities for parsing model configurations, inferring parameters,
|
|
and handling architecture-specific settings. Uses UK English spelling
|
|
conventions throughout.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from helpers.models.conversion import GGUFParameters, ModelConfig, VisionConfig
|
|
from helpers.services.filesystem import FilesystemService
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
|
|
class ConfigParser:
|
|
"""Parses and transforms model configuration files.
|
|
|
|
Handles loading of HuggingFace config.json files, parameter inference,
|
|
and conversion to GGUF-compatible formats. Provides sensible defaults
|
|
for missing values and architecture-specific handling.
|
|
"""
|
|
|
|

    def __init__(self) -> None:
        """Initialise ConfigParser."""
        self.fs = FilesystemService()

    def load_model_config(self, model_path: Path) -> ModelConfig:
        """Load model configuration from config.json file.

        Reads the standard HuggingFace config.json file and parses it into
        a structured ModelConfig instance with proper type validation. Handles
        vision model configurations and provides sensible defaults for
        missing values.

        Returns:
            Parsed ModelConfig instance.
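
        Example (illustrative; ``./my-model`` is a hypothetical local
        directory containing a HuggingFace ``config.json``):

            from pathlib import Path

            parser = ConfigParser()
            config = parser.load_model_config(Path("./my-model"))
            print(config.model_type, config.hidden_size)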
        """
        config_file = model_path / "config.json"
        raw_config = self.fs.load_json_config(config_file)

        # Parse vision config if present
        vision_config = None
        if "vision_config" in raw_config:
            vision_config = VisionConfig(**raw_config["vision_config"])

        # Create ModelConfig with parsed values
        return ModelConfig(
            architectures=raw_config.get("architectures", ["Unknown"]),
            model_type=raw_config.get("model_type", "unknown"),
            vocab_size=raw_config.get("vocab_size", 32000),
            max_position_embeddings=raw_config.get("max_position_embeddings", 2048),
            hidden_size=raw_config.get("hidden_size", 4096),
            num_hidden_layers=raw_config.get("num_hidden_layers", 32),
            intermediate_size=raw_config.get("intermediate_size", 11008),
            num_attention_heads=raw_config.get("num_attention_heads", 32),
            num_key_value_heads=raw_config.get("num_key_value_heads"),
            rope_theta=raw_config.get("rope_theta", 10000.0),
            rope_scaling=raw_config.get("rope_scaling"),
            rms_norm_eps=raw_config.get("rms_norm_eps", 1e-5),
            vision_config=vision_config,
        )

    def infer_gguf_parameters(self, config: ModelConfig) -> GGUFParameters:
        """Infer GGUF parameters from model configuration.

        Translates HuggingFace model configuration to GGUF parameter format,
        providing sensible defaults for missing values and handling various
        architecture conventions. Calculates derived parameters like RoPE
        dimensions and handles grouped-query attention configurations.

        Returns:
            GGUFParameters with inferred values and proper type validation.
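
        Example (illustrative; the arithmetic mirrors the Llama-style
        defaults used above, given a ModelConfig ``config``):

            # hidden_size=4096 and num_attention_heads=32 give
            # rope.dimension_count = 4096 // 32 = 128; a model with
            # num_key_value_heads=8 keeps 8 KV heads (grouped-query
            # attention) rather than falling back to 32.
            params = ConfigParser().infer_gguf_parameters(config)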
        """
        # Calculate derived parameters
        num_heads = config.num_attention_heads
        embedding_length = config.hidden_size
        rope_dimension_count = embedding_length // num_heads

        # Handle KV heads (for GQA models)
        num_kv_heads = config.num_key_value_heads or num_heads

        # Create GGUFParameters using dict with aliases
        params_dict = {
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "embedding_length": embedding_length,
            "block_count": config.num_hidden_layers,
            "feed_forward_length": config.intermediate_size,
            "attention.head_count": num_heads,
            "attention.head_count_kv": num_kv_heads,
            "attention.layer_norm_rms_epsilon": config.rms_norm_eps,
            "rope.freq_base": config.rope_theta,
            "rope.dimension_count": rope_dimension_count,
        }

        params = GGUFParameters.model_validate(params_dict)

        # Add RoPE scaling if present
        if config.rope_scaling:
            params.rope_scaling_type = config.rope_scaling.get("type", "linear")
            params.rope_scaling_factor = config.rope_scaling.get("factor", 1.0)

        return params

    @staticmethod
    def get_architecture_mapping(architecture: str) -> str:
        """Map architecture names to known GGUF architectures.

        Provides fallback mappings for architectures not directly supported
        by the GGUF format, translating them to similar known architectures.
        This enables broader model compatibility whilst maintaining GGUF
        standards.

        Returns:
            GGUF-compatible architecture name, falling back to llama for
            unknown architectures.
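
        Example (values follow the mapping table below; ``NewUnknownArch``
        is a made-up name to show the fallback):

            ConfigParser.get_architecture_mapping("MistralForCausalLM")  # "llama"
            ConfigParser.get_architecture_mapping("Qwen2ForCausalLM")  # "qwen2"
            ConfigParser.get_architecture_mapping("NewUnknownArch")  # "llama"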
        """
        # Architecture mappings to known GGUF types
        mappings = {
            "DotsOCRForCausalLM": "qwen2",  # Similar architecture
            "GptOssForCausalLM": "llama",  # Use llama as fallback
            "MistralForCausalLM": "llama",  # Mistral is llama-like
            "Qwen2ForCausalLM": "qwen2",
            "LlamaForCausalLM": "llama",
            "GemmaForCausalLM": "gemma",
            "Phi3ForCausalLM": "phi3",
            # Add more mappings as needed
        }

        return mappings.get(architecture, "llama")  # Default to llama

    @staticmethod
    def load_tokeniser_config(model_path: Path) -> dict[str, Any]:
        """Load tokeniser configuration from model directory.

        Reads tokenizer_config.json to extract special token IDs and other
        tokenisation parameters required for GGUF metadata. Provides sensible
        defaults when configuration files are missing or incomplete.

        Returns:
            Tokeniser configuration dictionary with token IDs and model type.
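
        Example (illustrative; ``./my-model`` is a hypothetical directory
        that may or may not contain a ``tokenizer_config.json``):

            from pathlib import Path

            tokens = ConfigParser.load_tokeniser_config(Path("./my-model"))
            print(tokens["bos_token_id"], tokens["eos_token_id"])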
        """
        fs = FilesystemService()
        tokeniser_config_path = model_path / "tokenizer_config.json"

        if not tokeniser_config_path.exists():
            # Return defaults if no config found
            return {
                "bos_token_id": 1,
                "eos_token_id": 2,
                "unk_token_id": 0,
                "pad_token_id": 0,
            }

        config = fs.load_json_config(tokeniser_config_path)

        # Extract token IDs with defaults
        return {
            "bos_token_id": config.get("bos_token_id", 1),
            "eos_token_id": config.get("eos_token_id", 2),
            "unk_token_id": config.get("unk_token_id", 0),
            "pad_token_id": config.get("pad_token_id", 0),
            "model_type": config.get("model_type", "llama"),
        }
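

# Minimal end-to-end usage sketch. Illustrative only: "./my-model" is a
# hypothetical local model directory, and the helpers.* modules must be
# importable for this to run.
if __name__ == "__main__":
    from pathlib import Path

    model_dir = Path("./my-model")  # hypothetical path
    parser = ConfigParser()

    model_config = parser.load_model_config(model_dir)
    gguf_params = parser.infer_gguf_parameters(model_config)
    arch = ConfigParser.get_architecture_mapping(model_config.architectures[0])
    tokeniser = ConfigParser.load_tokeniser_config(model_dir)

    print(f"architecture: {model_config.architectures[0]} -> {arch}")
    print(f"gguf parameters: {gguf_params}")
    print(f"special tokens: {tokeniser}")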