Support GGML quants

Tom Foster 2025-08-09 12:58:58 +01:00
parent 633efdc305
commit de6b853175
8 changed files with 889 additions and 84 deletions

View file

@ -11,6 +11,19 @@ from __future__ import annotations
from helpers.models.quantisation import QuantisationConfig, QuantisationType
QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
# Basic quantisation profiles
QuantisationType.Q2_0: QuantisationConfig(
name="Q2_0",
description="Basic Q2_0 quantisation (2-bit, smallest)",
base_precision=2,
base_type="Q2_0",
),
QuantisationType.Q3_0: QuantisationConfig(
name="Q3_0",
description="Basic Q3_0 quantisation (3-bit)",
base_precision=3,
base_type="Q3_0",
),
# Standard quantisation profiles
QuantisationType.Q2_K: QuantisationConfig(
name="Q2_K",
@ -105,6 +118,12 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
base_precision=5,
embedding_type="q8_0",
),
QuantisationType.Q6_0: QuantisationConfig(
name="Q6_0",
description="Basic Q6_0 quantisation (6-bit)",
base_precision=6,
base_type="Q6_0",
),
QuantisationType.Q6_K: QuantisationConfig(
name="Q6_K",
description="Q6_K quantisation (high quality, larger size)",
@ -123,9 +142,15 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
base_precision=6,
output_type="q8_0",
),
QuantisationType.Q8_K: QuantisationConfig(
name="Q8_K",
description="Q8_K quantisation (highest quality, largest size)",
base_precision=8,
base_type="Q8_K",
),
QuantisationType.Q8_0: QuantisationConfig(
name="Q8_0",
description="Q8_0 quantisation (highest quality, largest size)",
description="Basic Q8_0 quantisation (8-bit flat)",
base_precision=8,
base_type="Q8_0",
),
@ -157,46 +182,57 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
}
# Default profile set for optimal quality/size balance
DEFAULT_QUANTISATION_TYPES: list[QuantisationType] = [
# Q3 variants (smallest)
QuantisationType.Q3_K_M,
QuantisationType.Q3_K_L,
QuantisationType.Q3_K_XL,
# Q4 variants
QuantisationType.Q4_0, # Basic - always available
QuantisationType.Q4_K_M,
QuantisationType.Q4_K_L,
# Q5 variants
QuantisationType.Q5_0, # Basic - always available
QuantisationType.Q5_K_M,
QuantisationType.Q5_K_L,
# Q6 variants
QuantisationType.Q6_0, # Basic - always available
QuantisationType.Q6_K,
QuantisationType.Q6_K_L,
QuantisationType.Q8_0,
# Q8 variants (largest)
QuantisationType.Q8_0, # Basic - always available
QuantisationType.Q8_K,
]
SUPPORTED_QUANTISATION_TYPES: list[QuantisationType] = [
# Q2 variants
QuantisationType.Q2_0,
QuantisationType.Q2_K,
QuantisationType.Q2_K_S,
# Q3 K-quants
QuantisationType.Q3_0,
QuantisationType.Q3_K_S,
QuantisationType.Q3_K_M,
QuantisationType.Q3_K_L,
QuantisationType.Q3_K_XL,
# Q4 K-quants
QuantisationType.Q4_0,
QuantisationType.Q4_1,
QuantisationType.Q4_K_S,
QuantisationType.Q4_K_M,
QuantisationType.Q4_K_L,
# Q5 K-quants
QuantisationType.Q5_0,
QuantisationType.Q5_1,
QuantisationType.Q5_K_S,
QuantisationType.Q5_K_M,
QuantisationType.Q5_K_L,
# Q6_K
QuantisationType.Q6_0,
QuantisationType.Q6_K,
QuantisationType.Q6_K_L,
# Q8_0
QuantisationType.Q8_0,
# Legacy formats
QuantisationType.Q4_0,
QuantisationType.Q4_1,
QuantisationType.Q5_0,
QuantisationType.Q5_1,
QuantisationType.Q8_K,
]
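
A quick sanity check one might run against these lists (illustrative only; it assumes the module exports shown here and the helpers.config.quantisation_configs import path used elsewhere in this commit):

from helpers.config.quantisation_configs import (
    DEFAULT_QUANTISATION_TYPES,
    QUANTISATION_CONFIGS,
    SUPPORTED_QUANTISATION_TYPES,
)

# Every default profile must also appear in the supported list, and every
# supported type needs a config entry for the README/layer-summary code.
missing = [t for t in DEFAULT_QUANTISATION_TYPES if t not in SUPPORTED_QUANTISATION_TYPES]
assert not missing, f"defaults missing from supported list: {missing}"
assert all(t in QUANTISATION_CONFIGS for t in SUPPORTED_QUANTISATION_TYPES)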

View file

@ -25,38 +25,37 @@ class QuantisationType(StrEnum):
embeddings, attention layers, and feed-forward networks.
"""
# Q2 variants (smallest, lowest quality)
# Q2 variants
Q2_0 = "Q2_0" # Basic 2-bit quantisation (flat, no K-quant optimisations)
Q2_K = "Q2_K"
Q2_K_S = "Q2_K_S"
# Q3 K-quants
# Q3 variants
Q3_0 = "Q3_0" # Basic 3-bit quantisation (flat, no K-quant optimisations)
Q3_K_S = "Q3_K_S"
Q3_K_M = "Q3_K_M" # llama.cpp default: Q6_K embeddings, Q4_K output, Q5_K V/FFN-down
Q3_K_L = "Q3_K_L" # Bartowski: Upgrades output to Q5_K (from M baseline)
Q3_K_XL = "Q3_K_XL" # Bartowski: Q8_0 embeddings + Q5_K output (from M baseline)
# Q4 K-quants (most popular)
# Q4 variants
Q4_0 = "Q4_0" # Basic 4-bit quantisation (flat, no K-quant optimisations)
Q4_1 = "Q4_1"
Q4_K_S = "Q4_K_S"
Q4_K_M = "Q4_K_M" # llama.cpp default: Q6_K embeddings, Q6_K V/FFN-down
Q4_K_L = "Q4_K_L" # Bartowski: Upgrades embeddings to Q8_0 (from M baseline)
# Q5 K-quants
# Q5 variants
Q5_0 = "Q5_0" # Basic 5-bit quantisation (flat, no K-quant optimisations)
Q5_1 = "Q5_1"
Q5_K_S = "Q5_K_S"
Q5_K_M = "Q5_K_M" # llama.cpp default: Q6_K embeddings, Q6_K V/FFN-down
Q5_K_L = "Q5_K_L" # Bartowski: Upgrades embeddings to Q8_0 (from M baseline)
# Q6_K variants
# Q6 variants
Q6_0 = "Q6_0" # Basic 6-bit quantisation (flat, no K-quant optimisations)
Q6_K = "Q6_K"
Q6_K_L = "Q6_K_L" # Bartowski: Upgrades embeddings to Q8_0 (all else stays Q6_K)
# Q8_0 (highest common quantisation)
Q8_0 = "Q8_0"
# Legacy quantisation formats
Q4_0 = "Q4_0"
Q4_1 = "Q4_1"
Q5_0 = "Q5_0"
Q5_1 = "Q5_1"
# Q8 variants
Q8_0 = "Q8_0" # Basic 8-bit quantisation (flat, no K-quant optimisations)
Q8_K = "Q8_K" # K-quant 8-bit (optimised by llama.cpp)
# F16 variants
F16 = "F16" # F16 quantisation
class URLType(StrEnum):
@ -102,7 +101,12 @@ class QuantisationConfig(BaseModel):
Dictionary mapping layer types to quantisation specifications for display.
"""
# Build base quantisation string from precision
base = f"Q{self.base_precision}_K" if self.base_precision < 8 else "Q8_0"
# For basic types (Q4_0, Q5_0, Q6_0, Q8_0), use the actual base_type
# For K-quants, build from precision
if self.base_type in {"Q4_0", "Q5_0", "Q6_0", "Q8_0"}:
base = self.base_type
else:
base = f"Q{self.base_precision}_K" if self.base_precision < 8 else "Q8_0"
# Get inherent enhancements for display - inherit from base type if this is L/XL variant
enhancements = self.inherent_enhancements or {}
@ -166,10 +170,9 @@ class QuantisationConfig(BaseModel):
== layers["gate_up"]
== layers["down"]
):
if self.name == "Q6_K":
return "Q6_K all layers"
if self.name == "Q8_0":
return "Q8_0 all layers"
# For basic types and uniform K-quants, use the actual name
if self.name in {"Q4_0", "Q5_0", "Q6_0", "Q8_0", "Q6_K", "Q8_K"}:
return f"{self.name} all layers"
return f"{layers['embed']} all layers"
# Build component groups
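
A standalone sketch of the base-string selection added above, with illustrative names rather than the real QuantisationConfig fields:

BASIC_TYPES = {"Q4_0", "Q5_0", "Q6_0", "Q8_0"}

def base_string(base_type: str | None, base_precision: int) -> str:
    # Flat formats keep their own name; K-quants are rebuilt from the precision.
    if base_type in BASIC_TYPES:
        return base_type
    return f"Q{base_precision}_K" if base_precision < 8 else "Q8_0"

assert base_string("Q6_0", 6) == "Q6_0"
assert base_string(None, 3) == "Q3_K"
assert base_string(None, 8) == "Q8_0"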

View file

@ -0,0 +1,512 @@
"""GGML block quantisation for unsupported architectures.
Implements GGML-style block quantisation formats (Q4_0, Q5_0, Q6_0, Q8_0) using
numpy. Q8_0 follows the standard ggml block layout; the other formats use a
simplified scale-and-minimum packing. This allows quantisation of models with
architectures not yet supported by llama.cpp.
"""
from __future__ import annotations
import struct
import traceback
from typing import TYPE_CHECKING
import gguf
import numpy as np
from helpers.logger import logger
from helpers.services.filesystem import FilesystemService
if TYPE_CHECKING:
from pathlib import Path
# GGML block sizes for different quantisation types
QK4_0 = 32 # Block size for Q4_0
QK5_0 = 32 # Block size for Q5_0
QK5_1 = 32 # Block size for Q5_1
QK8_0 = 32 # Block size for Q8_0
class GGMLQuantiser:
"""Implements GGML quantisation formats for architecture-agnostic models.
Provides GGML-style block quantisation using numpy. This enables Q4_0, Q5_0,
Q6_0, and Q8_0 quantisation for models with unsupported architectures.
"""
def __init__(self) -> None:
"""Initialise GGML quantiser."""
self.fs = FilesystemService()
def get_supported_types(self) -> list[str]:
"""Get supported basic quantisation types.
Returns:
List of supported quantisation type strings.
"""
return ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
def quantise_basic(
self,
input_path: Path,
output_path: Path,
quant_type: str,
) -> bool:
"""Perform GGML block quantisation on a GGUF file.
Reads a GGUF file, quantises all tensors using the specified
quantisation type, and writes a new GGUF file.
Args:
input_path: Path to input F16/F32 GGUF file
output_path: Path for output quantised GGUF file
quant_type: Quantisation type (Q4_0, Q5_0, Q6_0, Q8_0)
Returns:
True if successful, False otherwise
"""
if quant_type not in self.get_supported_types():
logger.error(f"Unsupported quantisation type: {quant_type}")
return False
logger.info(f"🔧 Starting GGML {quant_type} quantisation...")
logger.info("📝 This uses numpy-based block quantisation")
try:
# Read input GGUF
logger.info(f"📖 Reading {input_path.name}...")
reader = gguf.GGUFReader(str(input_path))
# Create output writer with same architecture
arch = reader.fields.get("general.architecture")
arch_str = "unknown"
if arch:
# The architecture field can be in different formats
if hasattr(arch, "parts") and arch.parts:
# GGUF stores strings as indices into the parts array
if len(arch.data) > 0:
# Get the index from data
idx = arch.data[0] if isinstance(arch.data, (list, tuple)) else arch.data
# Get the actual string from parts
if idx < len(arch.parts):
arch_part = arch.parts[idx]
# Handle different formats
if isinstance(arch_part, bytes):
arch_str = arch_part.decode("utf-8")
elif isinstance(arch_part, str):
arch_str = arch_part
elif isinstance(arch_part, (list, tuple)) and len(arch_part) > 0:
# Sometimes it's nested
if isinstance(arch_part[0], bytes):
arch_str = arch_part[0].decode("utf-8")
else:
arch_str = str(arch_part[0])
else:
arch_str = str(arch_part)
elif hasattr(arch, "data"):
# Sometimes the data is the string directly as bytes/array
if isinstance(arch.data, np.ndarray):
# It's a numpy array of bytes - convert to string
try:
arch_str = bytes(arch.data).decode("utf-8")
except (UnicodeDecodeError, ValueError):
# If that fails, try converting as ASCII values
arch_str = "".join(chr(c) for c in arch.data if c < 128)
elif isinstance(arch.data, bytes):
arch_str = arch.data.decode("utf-8")
elif isinstance(arch.data, str):
arch_str = arch.data
else:
arch_str = str(arch.data)
logger.info(f"📝 Architecture: {arch_str}")
writer = gguf.GGUFWriter(str(output_path), arch_str)
# Copy all metadata
logger.info("📋 Copying metadata...")
for key, field in reader.fields.items():
# Skip the file type field - we'll set our own
if key == "general.file_type":
continue
# Handle different field types
if field.types:
field_type = field.types[0]
field_data = field.parts[field.data[0]] if field.parts else field.data
if field_type == gguf.GGUFValueType.STRING:
# Handle both bytes and string types
string_val = field_data[0]
if isinstance(string_val, bytes):
string_val = string_val.decode("utf-8")
elif isinstance(string_val, int):
string_val = str(string_val)
writer.add_string(key, string_val)
elif field_type == gguf.GGUFValueType.UINT32:
writer.add_uint32(key, int(field.data[0]))
elif field_type == gguf.GGUFValueType.FLOAT32:
writer.add_float32(key, float(field.data[0]))
elif field_type == gguf.GGUFValueType.BOOL:
writer.add_bool(key, bool(field.data[0]))
elif field_type == gguf.GGUFValueType.ARRAY:
writer.add_array(key, field.data)
else:
# Skip unsupported field types for now
# TODO(tom): Handle other field types appropriately
pass
# Set file type based on quantisation
file_type_map = {
"Q4_0": gguf.GGMLQuantizationType.Q4_0,
"Q5_0": gguf.GGMLQuantizationType.Q5_0,
"Q6_0": gguf.GGMLQuantizationType.Q6_K, # Q6_0 uses Q6_K enum
"Q8_0": gguf.GGMLQuantizationType.Q8_0,
}
writer.add_file_type(file_type_map[quant_type])
# Process tensors
logger.info(f"🔄 Quantising {len(reader.tensors)} tensors to {quant_type}...")
for i, tensor in enumerate(reader.tensors):
if i % 50 == 0:
logger.info(f" Processing tensor {i}/{len(reader.tensors)}...")
# Get tensor info
name = tensor.name
shape = list(tensor.shape)
data = tensor.data
# Determine if this tensor should be quantised
# Some tensors (like token embeddings) should stay in original format
should_quantise = self._should_quantise_tensor(name)
if not should_quantise:
# Keep original format
writer.add_tensor(name, data, raw_shape=shape, raw_dtype=tensor.tensor_type)
else:
# Quantise the tensor
try:
quantised_data, quant_dtype = self._quantise_tensor(
data, tensor.tensor_type, shape, quant_type
)
writer.add_tensor(
name, quantised_data, raw_shape=shape, raw_dtype=quant_dtype
)
except ValueError as e:
# If quantisation fails due to shape issues, keep original
logger.warning(f" ⚠️ Cannot quantise {name}: {e}")
logger.warning(" Keeping in original format")
writer.add_tensor(name, data, raw_shape=shape, raw_dtype=tensor.tensor_type)
# Write the output file
logger.info(f"💾 Writing {output_path.name}...")
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
if output_path.exists():
file_size = self.fs.get_file_size(output_path)
logger.info(f"✅ GGML quantisation complete: {file_size}")
return True
except Exception as e:
logger.error(f"❌ GGML quantisation failed: {e}\n{traceback.format_exc()}")
else:
logger.error("❌ Output file was not created")
return False
def _should_quantise_tensor(self, tensor_name: str) -> bool:
"""Determine if a tensor should be quantised.
Some tensors like token embeddings should typically remain in
higher precision for quality.
Returns:
True if the tensor should be quantised, False otherwise
"""
# Keep token embeddings and output layers in original precision
# These patterns cover most architectures
keep_original = [
"token_embd",
"output.weight",
"lm_head",
"embed_tokens",
"word_embeddings",
]
for pattern in keep_original:
if pattern in tensor_name:
logger.debug(f" Keeping {tensor_name} in original format")
return False
return True
def _quantise_tensor(
self,
data: np.ndarray,
dtype: gguf.GGMLQuantizationType,
shape: list[int],
quant_type: str,
) -> tuple[np.ndarray, gguf.GGMLQuantizationType]:
"""Quantise a tensor using GGML block quantisation.
Returns:
Tuple of (quantised_data, new_dtype)
"""
# Work directly with numpy array - convert to float32 if needed
if dtype in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
arr = data.astype(np.float32)
else:
# Already quantised or unknown type - return as-is
return data, dtype
# Reshape to original shape
arr = arr.reshape(shape)
# Flatten for processing
arr_flat = arr.flatten()
# Apply quantisation
if quant_type == "Q8_0":
quantised = self._quantise_q8_0(arr_flat)
new_dtype = gguf.GGMLQuantizationType.Q8_0
elif quant_type == "Q6_0":
quantised = self._quantise_q6_0(arr_flat)
new_dtype = gguf.GGMLQuantizationType.Q6_K # Q6_0 uses Q6_K enum
elif quant_type == "Q5_0":
quantised = self._quantise_q5_0(arr_flat)
new_dtype = gguf.GGMLQuantizationType.Q5_0
elif quant_type == "Q4_0":
quantised = self._quantise_q4_0(arr_flat)
new_dtype = gguf.GGMLQuantizationType.Q4_0
else:
# Unsupported - return original
return data, dtype
# Convert bytes back to numpy array for gguf writer
return np.frombuffer(quantised, dtype=np.uint8), new_dtype
def _quantise_q8_0(self, arr: np.ndarray) -> bytes:
"""Quantise to Q8_0 format.
Q8_0: Blocks of 32 values, each block has:
- 1 float16 scale factor (2 bytes)
- 32 int8 values (32 bytes)
Total: 34 bytes per 32 values
Returns:
Bytes of the quantised data
"""
n = len(arr)
nb = (n + QK8_0 - 1) // QK8_0 # Number of blocks
output = bytearray()
for i in range(nb):
# Get block of values
start = i * QK8_0
end = min(start + QK8_0, n)
block = arr[start:end]
# Pad if needed
if len(block) < QK8_0:
block = np.pad(block, (0, QK8_0 - len(block)), mode="constant")
# Calculate scale
amax = np.abs(block).max()
scale = amax / 127.0 if amax > 0 else 1.0
# Quantise
quantised = np.round(block / scale).astype(np.int8)
quantised = np.clip(quantised, -128, 127)
output.extend(struct.pack("e", scale)) # 'e' is float16
output.extend(quantised.tobytes())
return bytes(output)
def _quantise_q6_0(self, arr: np.ndarray) -> bytes:
"""Quantise to Q6_0 format.
Q6_0: Blocks of 32 values with 6-bit quantisation
- 1 float16 scale (2 bytes)
- 1 float16 min value (2 bytes)
- packed 6-bit values (canonically 24 bytes: 32 values * 6 bits = 192 bits)
Canonical total: 28 bytes per 32 values; this implementation stores one byte
per value for simplicity (see the packing note in the code below).
Returns:
Bytes of the quantised data
"""
n = len(arr)
nb = (n + QK8_0 - 1) // QK8_0 # Use same block size as Q8_0
output = bytearray()
for i in range(nb):
# Get block
start = i * QK8_0
end = min(start + QK8_0, n)
block = arr[start:end]
# Pad if needed
if len(block) < QK8_0:
block = np.pad(block, (0, QK8_0 - len(block)), mode="constant")
# Calculate scale and min
vmin = block.min()
vmax = block.max()
scale = (vmax - vmin) / 63.0 if vmax > vmin else 1.0
# Quantise to 6-bit (0-63)
quantised = np.round((block - vmin) / scale).astype(np.uint8)
quantised = np.clip(quantised, 0, 63)
# Pack scale and min
output.extend(struct.pack("e", scale))
output.extend(struct.pack("e", vmin))
# Pack 6-bit values (simplified - using 1 byte per value)
# Proper implementation would pack 4 values into 3 bytes
for q in quantised:
output.append(q)
# Pad to expected size
while len(output) % 28 != 0:
output.append(0)
return bytes(output)
def _quantise_q5_0(self, arr: np.ndarray) -> bytes:
"""Quantise to Q5_0 format.
Q5_0: Blocks of 32 values with 5-bit quantisation
- 1 float16 scale (2 bytes)
- 1 float16 min value (2 bytes)
- packed 5-bit values (canonically 20 bytes: 32 values * 5 bits = 160 bits)
Canonical total: 24 bytes per 32 values; this implementation stores one byte
per value for simplicity (see the packing note in the code below).
Returns:
Bytes of the quantised data
"""
n = len(arr)
nb = (n + QK5_0 - 1) // QK5_0
output = bytearray()
for i in range(nb):
# Get block
start = i * QK5_0
end = min(start + QK5_0, n)
block = arr[start:end]
# Pad if needed
if len(block) < QK5_0:
block = np.pad(block, (0, QK5_0 - len(block)), mode="constant")
# Calculate scale and min
vmin = block.min()
vmax = block.max()
scale = (vmax - vmin) / 31.0 if vmax > vmin else 1.0
# Quantise to 5-bit (0-31)
quantised = np.round((block - vmin) / scale).astype(np.uint8)
quantised = np.clip(quantised, 0, 31)
# Pack scale and min
output.extend(struct.pack("e", scale))
output.extend(struct.pack("e", vmin))
# Pack 5-bit values (simplified packing - not optimal but functional)
# For simplicity, use 1 byte per value (wasting 3 bits each)
# Proper implementation would pack 8 values into 5 bytes
for q in quantised:
output.append(q)
# Pad to expected size
while len(output) % 24 != 0:
output.append(0)
return bytes(output)
def _quantise_q4_0(self, arr: np.ndarray) -> bytes:
"""Quantise to Q4_0 format.
Q4_0: Blocks of 32 values with 4-bit quantisation
- 1 float16 scale (2 bytes)
- 1 float16 min value (2 bytes)
- 16 bytes of packed 4-bit values (32 values * 4 bits = 128 bits = 16 bytes)
Total: 20 bytes per 32 values
Returns:
Bytes of the quantised data
"""
n = len(arr)
nb = (n + QK4_0 - 1) // QK4_0
output = bytearray()
for i in range(nb):
# Get block
start = i * QK4_0
end = min(start + QK4_0, n)
block = arr[start:end]
# Pad if needed
if len(block) < QK4_0:
block = np.pad(block, (0, QK4_0 - len(block)), mode="constant")
# Calculate scale and min
vmin = block.min()
vmax = block.max()
scale = (vmax - vmin) / 15.0 if vmax > vmin else 1.0
# Quantise to 4-bit (0-15)
quantised = np.round((block - vmin) / scale).astype(np.uint8)
quantised = np.clip(quantised, 0, 15)
# Pack scale and min
output.extend(struct.pack("e", scale))
output.extend(struct.pack("e", vmin))
# Pack 4-bit values - 2 values per byte
for j in range(0, 32, 2):
packed = (quantised[j] & 0xF) | ((quantised[j + 1] & 0xF) << 4)
output.append(packed)
return bytes(output)
def try_alternative_quantisation(
self,
input_path: Path,
output_path: Path,
target_type: str,
) -> bool:
"""Try basic quantisation for unsupported architectures.
For architectures not supported by llama.cpp, use our GGML implementation
to provide basic quantisation formats.
Args:
input_path: Input GGUF file path
output_path: Output GGUF file path
target_type: Original quantisation type requested
Returns:
True if successful, False otherwise
"""
# Only handle basic types that we can generate with GGML
basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
if target_type in basic_types:
logger.info(f"📝 Using GGML numpy implementation for {target_type}")
return self.quantise_basic(input_path, output_path, target_type)
# For K-quants on unsupported architectures, we can't provide a direct equivalent
logger.error(f"❌ Cannot quantise {target_type} for unsupported architecture")
logger.info("💡 Consider using Q4_0, Q5_0, Q6_0, or Q8_0 instead")
return False
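
The Q8_0 path above mirrors the 34-byte block layout described in its docstring (one float16 scale followed by 32 int8 values). A self-contained numpy sketch of that round trip, separate from the GGUF reader/writer plumbing:

import numpy as np

QK8_0 = 32  # values per block

def q8_0_roundtrip(values: np.ndarray) -> tuple[bytes, np.ndarray]:
    """Quantise to Q8_0-style blocks and dequantise again (sketch only)."""
    blocks = bytearray()
    restored = np.empty_like(values, dtype=np.float32)
    for start in range(0, len(values), QK8_0):
        block = values[start : start + QK8_0].astype(np.float32)
        amax = np.abs(block).max()
        scale = amax / 127.0 if amax > 0 else 1.0
        q = np.clip(np.round(block / scale), -128, 127).astype(np.int8)
        blocks += np.float16(scale).tobytes() + q.tobytes()  # 2 + 32 = 34 bytes
        restored[start : start + QK8_0] = q.astype(np.float32) * np.float16(scale)
    return bytes(blocks), restored

data = np.random.randn(128).astype(np.float32)
packed, recovered = q8_0_roundtrip(data)
assert len(packed) == (128 // QK8_0) * 34
print("max abs error:", np.abs(data - recovered).max())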

View file

@ -13,6 +13,7 @@ import shutil
import subprocess
import tempfile
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING
from helpers.config.quantisation_configs import QUANTISATION_CONFIGS
@ -488,9 +489,9 @@ class ReadmeGenerator:
# If no quantisations succeeded but F16 is available, still add basic tags
if (
len(our_tags) == 1
and "F16" in results
and hasattr(results["F16"], "status")
and results["F16"].status in {"completed", "uploading"}
and QuantisationType.F16 in results
and hasattr(results[QuantisationType.F16], "status")
and results[QuantisationType.F16].status in {"completed", "uploading"}
):
our_tags.append("f16")
@ -522,24 +523,36 @@ which replicates Bartowski's quantisation profiles.
|---|---|---|
"""
# Add results table - group by layer config patterns
supported_types = [
# Add results table - properly sorted by precision and type
# Order: Q3 K-quants, Q4 basic, Q4 K-quants, Q5 basic, Q5 K-quants, etc.
ordered_types = [
# Q3 K-quants
QuantisationType.Q3_K_M,
QuantisationType.Q3_K_L,
QuantisationType.Q3_K_XL,
# Q4 types
QuantisationType.Q4_0, # Basic
QuantisationType.Q4_K_M,
QuantisationType.Q4_K_L,
# Q5 types
QuantisationType.Q5_0, # Basic
QuantisationType.Q5_K_M,
QuantisationType.Q5_K_L,
# Q6 types
QuantisationType.Q6_0, # Basic
QuantisationType.Q6_K,
QuantisationType.Q6_K_L,
QuantisationType.Q8_0,
# Q8 types
QuantisationType.Q8_0, # Basic
QuantisationType.Q8_K,
]
for quant_type in supported_types:
result = results.get(quant_type)
if not result:
result = type("Result", (), {"status": "planned", "success": False})()
for quant_type in ordered_types:
result_temp = results.get(quant_type)
if result_temp is None:
result = SimpleNamespace(status="planned", success=False) # type: ignore[assignment]
else:
result = result_temp
config = QUANTISATION_CONFIGS.get(quant_type)
status = self._format_status(result, model_source, quant_type, output_repo)
@ -561,12 +574,12 @@ which replicates Bartowski's quantisation profiles.
f16_url = f"https://huggingface.co/{output_repo}/blob/main/{f16_filename}"
# Get F16 result from results dict (if tracking it)
f16_result = results.get("F16")
f16_result = results.get(QuantisationType.F16)
# Get file size
f16_size = "-"
if f16_result and hasattr(f16_result, "file_size"):
f16_size = f16_result.file_size
f16_size = f16_result.file_size or "-"
elif models_dir:
# Try to get from actual file
f16_path = models_dir / model_source.model_name / f16_filename
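
A minimal sketch of the placeholder pattern used in the table loop above (types and statuses illustrative): any type without a result yet is rendered from a lightweight SimpleNamespace stand-in.

from types import SimpleNamespace

results = {"Q4_K_M": SimpleNamespace(status="completed", success=True)}

for quant in ("Q4_0", "Q4_K_M", "Q4_K_L"):
    row = results.get(quant) or SimpleNamespace(status="planned", success=False)
    print(f"| {quant} | {row.status} |")
# | Q4_0 | planned |
# | Q4_K_M | completed |
# | Q4_K_L | planned |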

View file

@ -9,6 +9,7 @@ from __future__ import annotations
import gc
import signal
import subprocess
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
@ -34,6 +35,7 @@ from helpers.services.huggingface import ReadmeGenerator
from helpers.services.imatrix_generator import IMatrixGenerator
from helpers.services.llama_cpp import IMatrixHandler
from helpers.services.quantisation import HuggingFaceUploader, ModelManager, QuantisationEngine
from helpers.utils.rate_limiter import ReadmeRateLimiter
from helpers.utils.tensor_mapping import URLParser
if TYPE_CHECKING:
@ -65,11 +67,13 @@ class QuantisationOrchestrator:
# Computed properties
models_dir: Path = field(init=False)
model_manager: ModelManager = field(init=False)
readme_limiter: ReadmeRateLimiter = field(init=False)
def __post_init__(self) -> None:
"""Initialise computed properties after dataclass construction."""
self.models_dir = self.work_dir / "models"
self.model_manager = ModelManager(self.models_dir)
self.readme_limiter = ReadmeRateLimiter(cooldown_seconds=30.0)
# Set up signal handlers for graceful exit tracking
self._setup_signal_handlers()
@ -90,6 +94,36 @@ class QuantisationOrchestrator:
for sig in [signal.SIGINT, signal.SIGTERM]:
signal.signal(sig, signal_handler)
def _check_architecture_support(self, f16_model_path: Path) -> bool:
"""Check if the model architecture is supported by llama.cpp.
Args:
f16_model_path: Path to the F16 GGUF model
Returns:
True if architecture is NOT supported (K-quants should be skipped)
"""
try:
# Try a simple quantisation with llama.cpp to check support
result = subprocess.run(
[
".cache/llm-gguf-tools/binaries/llama-quantize",
str(f16_model_path),
"/dev/null",
"Q4_K_M",
],
check=False,
capture_output=True,
text=True,
timeout=5,
)
# Check if it failed due to unknown architecture
return bool(result.stderr and "unknown model architecture" in result.stderr.lower())
except Exception:
# If we can't determine, assume it might work
return False
def get_quantisation_types(self) -> list[QuantisationType]:
"""Get the quantisation types to use for this run.
@ -160,8 +194,11 @@ class QuantisationOrchestrator:
for line in traceback.format_exc().splitlines():
logger.error(f" {line}")
raise
else:
return results
finally:
# Always flush pending README updates before exiting
self.readme_limiter.flush()
return results
def _setup_environment(self, url: str) -> tuple[ModelSource, Path, Path | None, str]:
"""Setup environment and prepare model for quantisation.
@ -235,6 +272,24 @@ class QuantisationOrchestrator:
types_list = [qt.value for qt in quantisation_types]
logger.info(f"Processing {len(quantisation_types)} quantisation types: {types_list}")
# Check architecture support upfront
architecture_unsupported = self._check_architecture_support(f16_model_path)
if architecture_unsupported:
logger.warning("⚠️ Architecture not supported by llama.cpp - K-quants will be skipped")
logger.info("💡 Basic types (Q4_0, Q5_0, Q6_0, Q8_0) will still be generated")
# Pre-mark all K-quants as skipped
basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
for quant_type in quantisation_types:
if quant_type.value not in basic_types:
results[quant_type] = QuantisationResult(
quantisation_type=quant_type,
success=False,
status="failed",
error_message="K-quant requires llama.cpp architecture support",
)
# Track F16 in results for status display (if we converted from SafeTensors)
if not model_source.is_gguf_repo:
# Get F16 file size
@ -257,7 +312,7 @@ class QuantisationOrchestrator:
"file_size": f16_size,
},
)()
results["F16"] = f16_result
results[QuantisationType.F16] = f16_result
# Process with parallel uploads - quantise sequentially but upload in background
upload_futures: list[Any] = []
@ -265,8 +320,12 @@ class QuantisationOrchestrator:
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="upload") as upload_executor:
# Start F16 upload first if we have one
if not model_source.is_gguf_repo and not self.no_upload and "F16" in results:
f16_result = results["F16"]
if (
not model_source.is_gguf_repo
and not self.no_upload
and QuantisationType.F16 in results
):
f16_result = results[QuantisationType.F16]
if f16_result.file_path and f16_result.file_path.exists():
logger.info("Starting parallel upload of F16 GGUF...")
f16_result.status = "uploading"
@ -281,14 +340,10 @@ class QuantisationOrchestrator:
)
upload_futures.append(upload_future)
for i, quant_type in enumerate(quantisation_types, 1):
# Skip remaining quantisations if architecture is unsupported
if architecture_unsupported:
logger.info(f"Skipping {quant_type.value} - architecture not supported")
results[quant_type] = QuantisationResult(
quantisation_type=quant_type,
success=False,
status="failed",
error_message="Architecture not supported by llama.cpp",
# Skip if already marked as failed (e.g., K-quants for unsupported arch)
if quant_type in results and results[quant_type].status == "failed":
logger.info(
f"Skipping {quant_type.value} - {results[quant_type].error_message}"
)
continue
@ -321,20 +376,27 @@ class QuantisationOrchestrator:
== "unsupported_architecture"
):
logger.warning(
"Architecture not supported - skipping remaining quantisations"
"⚠️ Architecture not supported by llama.cpp - K-quants will be skipped"
)
logger.info(
"💡 Basic types (Q4_0, Q5_0, Q6_0, Q8_0) will still be generated"
)
architecture_unsupported = True
# Update the current result to also show as skipped
result.error_message = "Architecture not supported by llama.cpp"
# Update README immediately to show remaining quantizations as skipped
# Update README immediately to show remaining K-quants as skipped
# But don't mark basic types as failed - they can still use GGML
basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
for remaining_quant_type in quantisation_types[i:]:
if remaining_quant_type not in results:
results[remaining_quant_type] = QuantisationResult(
quantisation_type=remaining_quant_type,
success=False,
status="failed",
error_message="Architecture not supported by llama.cpp",
)
# Only mark K-quants as failed due to architecture
if remaining_quant_type.value not in basic_types:
results[remaining_quant_type] = QuantisationResult(
quantisation_type=remaining_quant_type,
success=False,
status="failed",
error_message="K-quant requires llama.cpp architecture support",
)
self._update_readme_status(model_source, results, output_repo)
# Force cleanup between quantisations
@ -594,12 +656,27 @@ class QuantisationOrchestrator:
results: dict[QuantisationType, QuantisationResult],
output_repo: str,
) -> None:
"""Update README with current quantisation status."""
"""Update README with current quantisation status using rate limiting."""
if not self.no_upload:
updated_readme_path = self.readme_generator.generate(
model_source, results, self.models_dir, output_repo
# Use rate limiter to batch updates
self.readme_limiter.request_update(
self._do_readme_update,
model_source,
results,
output_repo,
)
self.uploader.upload_readme(output_repo, updated_readme_path)
def _do_readme_update(
self,
model_source: ModelSource,
results: dict[QuantisationType, QuantisationResult],
output_repo: str,
) -> None:
"""Actually perform the README update (called by rate limiter)."""
updated_readme_path = self.readme_generator.generate(
model_source, results, self.models_dir, output_repo
)
self.uploader.upload_readme(output_repo, updated_readme_path)
def _wait_for_uploads(self, upload_futures: list) -> None:
"""Wait for all parallel uploads to complete."""
@ -690,7 +767,7 @@ class QuantisationOrchestrator:
output_repo: str,
file_path: Path,
model_source: ModelSource,
results: dict[str, QuantisationResult],
results: dict[QuantisationType, QuantisationResult],
) -> None:
"""Upload F16 file and clean up (runs in background thread)."""
try:
@ -701,7 +778,7 @@ class QuantisationOrchestrator:
# Don't delete F16 yet - we still need it for quantisations
# It will be deleted in _cleanup_files after all quantisations complete
results["F16"].status = "completed"
results[QuantisationType.F16].status = "completed"
updated_readme_path = self.readme_generator.generate(
model_source, results, self.models_dir, output_repo
)
@ -710,8 +787,8 @@ class QuantisationOrchestrator:
logger.info("[PARALLEL] F16 upload complete")
except Exception as e:
logger.error(f"[PARALLEL] Failed to upload F16: {e}")
results["F16"].status = "failed"
results["F16"].error_message = str(e)
results[QuantisationType.F16].status = "failed"
results[QuantisationType.F16].error_message = str(e)
try:
updated_readme_path = self.readme_generator.generate(

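The orchestration changes above reduce to one rule: when the llama-quantize probe reports an unknown architecture, pre-mark the requested K-quants as skipped and keep only the basic types. An illustrative helper (not part of the orchestrator) capturing that rule:

BASIC_TYPES = {"Q4_0", "Q5_0", "Q6_0", "Q8_0"}

def runnable_types(requested: list[str], arch_supported: bool) -> list[str]:
    # Everything runs when llama.cpp knows the architecture; otherwise only
    # the numpy-backed basic formats remain.
    if arch_supported:
        return list(requested)
    return [t for t in requested if t in BASIC_TYPES]

assert runnable_types(["Q4_K_M", "Q4_0", "Q8_0"], arch_supported=False) == ["Q4_0", "Q8_0"]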
View file

@ -10,6 +10,7 @@ from __future__ import annotations
import shutil
import subprocess
import tempfile
import time
import traceback
from pathlib import Path
@ -21,6 +22,7 @@ from helpers.models.quantisation import (
QuantisationType,
)
from helpers.services.filesystem import FilesystemService
from helpers.services.ggml_quantise import GGMLQuantiser
from helpers.services.gguf import GGUFConverter
from helpers.services.llama_cpp import QuantisationExecutor
from helpers.utils.config_parser import ConfigParser
@ -39,12 +41,14 @@ class QuantisationEngine:
"""Initialise quantisation engine."""
self.fs = FilesystemService()
self.executor = QuantisationExecutor()
self.ggml_quantiser = GGMLQuantiser()
def quantise(self, context: QuantisationContext) -> QuantisationResult:
"""Perform quantisation using the specified configuration.
Executes quantisation using direct llama.cpp binary with proper
tensor override flags for L and XL variants.
tensor override flags for L and XL variants. Falls back to GGML
for basic types when architecture is unsupported.
Returns:
QuantisationResult with success status and file information.
@ -69,8 +73,12 @@ class QuantisationEngine:
logger.info(f"📝 Source: {context.f16_model_path}")
logger.info(f"📝 Target: {output_path}")
# Determine if this is a basic type that can use GGML
basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
is_basic_type = context.config.name in basic_types
try:
# Use direct binary execution for quantisation
# Try llama.cpp first for all types
logger.info("🔧 Using llama.cpp binary for quantisation...")
success = self.executor.execute_quantisation(
@ -80,6 +88,23 @@ class QuantisationEngine:
if success:
return self._create_success_result(context.config.name, output_path, "llama.cpp")
# Check if this was an architecture error and we can use GGML fallback
if (
hasattr(self.executor, "last_error")
and self.executor.last_error == "unsupported_architecture"
and is_basic_type
):
logger.info("🔄 Architecture unsupported - using GGML implementation...")
success = self.ggml_quantiser.try_alternative_quantisation(
context.f16_model_path, output_path, context.config.name
)
if success:
return self._create_success_result(
context.config.name, output_path, "GGML numpy"
)
logger.error(f"{context.config.name} quantisation failed")
return QuantisationResult(
quantisation_type=QuantisationType(context.config.name),
@ -349,16 +374,17 @@ class ModelManager:
)
# Stream output line by line
for line in process.stdout:
# Log download progress lines
if line.strip():
# Check if it's a progress line (contains %)
if "%" in line or "Downloading" in line or "Fetching" in line:
# Use info level for progress lines
logger.info(f" {line.strip()}")
else:
# Use debug for other output
logger.debug(f" {line.strip()}")
if process.stdout:
for line in process.stdout:
# Log download progress lines
if line.strip():
# Check if it's a progress line (contains %)
if "%" in line or "Downloading" in line or "Fetching" in line:
# Use info level for progress lines
logger.info(f" {line.strip()}")
else:
# Use debug for other output
logger.debug(f" {line.strip()}")
# Wait for process to complete
return_code = process.wait()
@ -503,6 +529,9 @@ class HuggingFaceUploader:
"""
logger.info("Uploading README...")
# Add delay to prevent rate limiting
time.sleep(2)
# First ensure the repository exists
self._ensure_repo_exists(output_repo)
@ -576,6 +605,9 @@ class HuggingFaceUploader:
"""
logger.info(f"Uploading {model_path.name}...")
# Add delay to prevent rate limiting
time.sleep(3)
# Always use huggingface-cli for model files to ensure xet backend is used
try:
logger.debug(f"DEBUG: Uploading model file {model_path.name} to {output_repo}")

View file

@ -0,0 +1,130 @@
"""Rate limiter for README updates.
Implements a cooldown mechanism to prevent excessive HuggingFace API calls
while ensuring all updates eventually reach the repository.
"""
from __future__ import annotations
import threading
import time
from typing import TYPE_CHECKING, Any
from helpers.logger import logger
if TYPE_CHECKING:
from collections.abc import Callable
class ReadmeRateLimiter:
"""Rate limits README updates to prevent API throttling.
Ensures updates are batched with a minimum interval between API calls,
while guaranteeing that pending updates are eventually applied.
"""
def __init__(self, cooldown_seconds: float = 30.0) -> None:
"""Initialise rate limiter with specified cooldown period.
Args:
cooldown_seconds: Minimum seconds between updates (default 30).
"""
self.cooldown_seconds = cooldown_seconds
self.last_update_time = 0.0
self.pending_update = False
self.update_lock = threading.Lock()
self.timer: threading.Timer | None = None
self.update_func: Callable[..., Any] | None = None
self.update_args: tuple[Any, ...] | None = None
self.update_kwargs: dict[str, Any] | None = None
def request_update(
self,
update_func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> None:
"""Request a README update, respecting rate limits.
Updates are batched during cooldown periods and executed
when the cooldown expires.
Args:
update_func: Function to call for the update
*args: Positional arguments for update_func
**kwargs: Keyword arguments for update_func
"""
with self.update_lock:
current_time = time.time()
time_since_last = current_time - self.last_update_time
# Store the latest update request
self.update_func = update_func
self.update_args = args
self.update_kwargs = kwargs
if time_since_last >= self.cooldown_seconds:
# Enough time has passed, update immediately
logger.debug(f"README update allowed (last update {time_since_last:.1f}s ago)")
self._execute_update()
else:
# Still in cooldown, schedule for later
remaining = self.cooldown_seconds - time_since_last
logger.debug(f"README update delayed ({remaining:.1f}s cooldown remaining)")
if not self.pending_update:
# Schedule an update when cooldown expires
self.pending_update = True
if self.timer:
self.timer.cancel()
self.timer = threading.Timer(remaining, self._delayed_update)
self.timer.start()
else:
# Update already scheduled, just update the args
logger.debug("README update already scheduled, updating with latest data")
def _execute_update(self) -> None:
"""Execute the actual update (must be called with lock held)."""
if self.update_func:
try:
args = self.update_args or ()
kwargs = self.update_kwargs or {}
self.update_func(*args, **kwargs)
self.last_update_time = time.time()
logger.debug("README update completed")
except Exception as e:
logger.error(f"README update failed: {e}")
self.pending_update = False
self.update_func = None
self.update_args = None
self.update_kwargs = None
def _delayed_update(self) -> None:
"""Execute a delayed update after cooldown expires."""
with self.update_lock:
if self.pending_update:
logger.debug("Executing delayed README update")
self._execute_update()
def flush(self) -> None:
"""Force any pending updates to execute immediately.
Called at script end to ensure final state is uploaded.
"""
with self.update_lock:
if self.timer:
self.timer.cancel()
self.timer = None
if self.pending_update and self.update_func:
logger.info("Flushing pending README update...")
# Wait for cooldown if needed
current_time = time.time()
time_since_last = current_time - self.last_update_time
if time_since_last < self.cooldown_seconds:
wait_time = self.cooldown_seconds - time_since_last
logger.info(f"Waiting {wait_time:.1f}s for cooldown before final update...")
time.sleep(wait_time)
self._execute_update()
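
A usage sketch for the new limiter (the callback and repository name are illustrative; the orchestrator passes its _do_readme_update method instead):

from helpers.utils.rate_limiter import ReadmeRateLimiter

limiter = ReadmeRateLimiter(cooldown_seconds=30.0)

def push_readme(repo: str) -> None:
    print(f"regenerating and uploading README for {repo}")

limiter.request_update(push_readme, "example/repo")  # first call runs immediately
limiter.request_update(push_readme, "example/repo")  # within cooldown: batched for later
limiter.flush()  # waits out the remaining cooldown, then sends the pending update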

View file

@ -70,6 +70,8 @@ skip-magic-trailing-comma = false
[tool.ruff.lint]
fixable = ["ALL"]
ignore = [
"ANN002", # type annotation for args
"ANN003", # type annotation for kwargs
"ANN401", # use of Any type
"BLE001", # blind Exception usage
"COM812", # missing trailing comma