"""GGUF file reading operations.
|
|
|
|
Provides utilities for reading and extracting information from GGUF files.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import gguf
|
|
import numpy as np
|
|
|
|
from helpers.logger import logger
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
|
|
class GGUFReader:
    """Reads and extracts information from GGUF files.

    Provides methods to read metadata, architecture information, and tensors
    from existing GGUF files for inspection or re-quantisation.
    """

    # Mapping of llama.cpp ``general.file_type`` codes to quantisation names.
    # Hoisted to class level so the table is built once, not per call.
    FILE_TYPE_MAP: dict[int, str] = {
        0: "F32",
        1: "F16",
        2: "Q4_0",
        3: "Q4_1",
        7: "Q8_0",
        8: "Q5_0",
        9: "Q5_1",
        10: "Q2_K",
        11: "Q3_K_S",
        12: "Q3_K_M",
        13: "Q3_K_L",
        14: "Q4_K_S",
        15: "Q4_K_M",
        16: "Q5_K_S",
        17: "Q5_K_M",
        18: "Q6_K",
    }

    def __init__(self, file_path: Path) -> None:
        """Initialise GGUF reader with file path.

        Sets up the internal GGUF reader instance for subsequent metadata
        and tensor extraction operations on the specified file.
        """
        self.file_path = file_path
        self.reader = gguf.GGUFReader(str(file_path))

    def get_architecture(self) -> str:
        """Extract architecture string from GGUF file.

        Returns:
            Architecture string or "unknown" if not found.
        """
        arch = self.reader.fields.get("general.architecture")
        if not arch:
            return "unknown"

        # Try extracting from parts array format
        if hasattr(arch, "parts") and arch.parts:
            return self._extract_from_parts(arch)

        # Try extracting from data field directly
        if hasattr(arch, "data"):
            return self._extract_from_data(arch.data)

        return "unknown"

    def _extract_from_parts(self, arch: Any) -> str:
        """Extract architecture from parts array.

        ``arch.data`` holds index(es) into ``arch.parts``. It may be a list,
        tuple or numpy array; always resolve to the *first* index. (The
        previous list/tuple-only check let a whole numpy array through as the
        index, which then failed when used to subscript ``parts``.)

        Returns:
            Architecture string or "unknown".
        """
        if len(arch.data) == 0:
            return "unknown"

        raw = arch.data
        first = raw[0] if isinstance(raw, (list, tuple, np.ndarray)) else raw
        try:
            idx = int(first)
        except (TypeError, ValueError):
            # Non-numeric index payload — cannot resolve a part.
            return "unknown"

        # Reject out-of-range indices, including negative ones that would
        # otherwise silently wrap around via Python negative indexing.
        if not 0 <= idx < len(arch.parts):
            return "unknown"

        return self._decode_arch_part(arch.parts[idx])

    def _decode_arch_part(self, arch_part: Any) -> str:
        """Decode architecture part to string.

        Returns:
            Decoded architecture string.
        """
        if isinstance(arch_part, bytes):
            return arch_part.decode("utf-8")
        if isinstance(arch_part, str):
            return arch_part
        if isinstance(arch_part, (list, tuple)) and len(arch_part) > 0:
            # Handle nested format: take the first element.
            inner = arch_part[0]
            if isinstance(inner, bytes):
                return inner.decode("utf-8")
            return str(inner)
        return str(arch_part)

    def _extract_from_data(self, data: Any) -> str:
        """Extract architecture from data field.

        Returns:
            Architecture string or "unknown".
        """
        if isinstance(data, np.ndarray):
            # Convert numpy array of byte values to string
            try:
                return bytes(data).decode("utf-8")
            except (UnicodeDecodeError, ValueError):
                # Fallback: keep only 7-bit ASCII code points
                return "".join(chr(c) for c in data if c < 128)
        if isinstance(data, bytes):
            return data.decode("utf-8")
        if isinstance(data, str):
            return data
        return str(data)

    def get_metadata(self) -> dict[str, Any]:
        """Extract all metadata from GGUF file.

        Returns:
            Dictionary of metadata fields and values.
        """
        metadata: dict[str, Any] = {}

        for key, field in self.reader.fields.items():
            # Explicit ``len()`` instead of truthiness: a multi-element numpy
            # array raises ValueError on ``bool()``; a list behaves the same.
            if not field.types or field.data is None or len(field.data) == 0:
                continue

            field_type = field.types[0]
            field_data = field.parts[field.data[0]] if field.parts else field.data

            # Convert data based on type
            if field_type == gguf.GGUFValueType.STRING:
                metadata[key] = self._decode_string_value(field_data)
            elif field_type in {
                gguf.GGUFValueType.UINT32,
                gguf.GGUFValueType.INT32,
                gguf.GGUFValueType.FLOAT32,
                gguf.GGUFValueType.BOOL,
            }:
                metadata[key] = (
                    field.data[0] if isinstance(field.data, (list, tuple)) else field.data
                )
            elif field_type == gguf.GGUFValueType.ARRAY:
                metadata[key] = list(field.data)
            # Other value types are intentionally skipped, as before.

        return metadata

    @staticmethod
    def _decode_string_value(field_data: Any) -> Any:
        """Decode a STRING-typed field payload.

        Returns:
            Decoded string (UTF-8 for bytes) or a string representation.
        """
        if isinstance(field_data, (list, tuple)) and field_data:
            value = field_data[0]
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            return value
        return str(field_data)

    def get_tensor_info(self) -> list[dict[str, Any]]:
        """Get information about all tensors in the file.

        Returns:
            List of tensor info dictionaries with name, shape, and type.
        """
        tensor_info = []

        for tensor in self.reader.tensors:
            tensor_type = tensor.tensor_type
            info = {
                "name": tensor.name,
                "shape": list(tensor.shape),
                "type": tensor_type.name
                if hasattr(tensor_type, "name")
                else str(tensor_type),
                # Quantised tensors expose raw byte buffers; fall back to
                # len() when the data object has no nbytes attribute.
                "size_bytes": tensor.data.nbytes
                if hasattr(tensor.data, "nbytes")
                else len(tensor.data),
            }
            tensor_info.append(info)

        return tensor_info

    def get_quantisation_type(self) -> str | None:
        """Get the quantisation type of the GGUF file.

        Returns:
            Quantisation type string or None if not found.
        """
        file_type = self.reader.fields.get("general.file_type")
        if not file_type or not hasattr(file_type, "data"):
            return None

        # Map numeric file type to string
        value = (
            file_type.data[0] if isinstance(file_type.data, (list, tuple)) else file_type.data
        )
        try:
            code = int(value)
        except (TypeError, ValueError):
            # Non-numeric payload: report it instead of raising.
            return f"Unknown ({value})"

        return self.FILE_TYPE_MAP.get(code, f"Unknown ({value})")

    def validate(self) -> bool:
        """Validate that the GGUF file is properly formatted.

        Returns:
            True if file is valid, False otherwise.
        """
        try:
            # Check basic structure
            if not self.reader.fields:
                logger.error("No metadata fields found")
                return False

            # Check for required fields
            required_fields = ["general.architecture"]
            for field in required_fields:
                if field not in self.reader.fields:
                    logger.error(f"Missing required field: {field}")
                    return False

            # Missing tensors are unusual but not treated as invalid.
            if not self.reader.tensors:
                logger.warning("No tensors found in file")

        except Exception as e:
            # Any reader failure (corrupt offsets, truncated file) is invalid.
            logger.error(f"Validation failed: {e}")
            return False
        else:
            return True
|