Support GGML quants

This commit is contained in:
parent 633efdc305
commit de6b853175

8 changed files with 889 additions and 84 deletions
@@ -11,6 +11,19 @@ from __future__ import annotations

from helpers.models.quantisation import QuantisationConfig, QuantisationType

QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
+    # Basic quantisation profiles
+    QuantisationType.Q2_0: QuantisationConfig(
+        name="Q2_0",
+        description="Basic Q2_0 quantisation (2-bit, smallest)",
+        base_precision=2,
+        base_type="Q2_0",
+    ),
+    QuantisationType.Q3_0: QuantisationConfig(
+        name="Q3_0",
+        description="Basic Q3_0 quantisation (3-bit)",
+        base_precision=3,
+        base_type="Q3_0",
+    ),
    # Standard quantisation profiles
    QuantisationType.Q2_K: QuantisationConfig(
        name="Q2_K",
@@ -105,6 +118,12 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
        base_precision=5,
        embedding_type="q8_0",
    ),
+    QuantisationType.Q6_0: QuantisationConfig(
+        name="Q6_0",
+        description="Basic Q6_0 quantisation (6-bit)",
+        base_precision=6,
+        base_type="Q6_0",
+    ),
    QuantisationType.Q6_K: QuantisationConfig(
        name="Q6_K",
        description="Q6_K quantisation (high quality, larger size)",
@@ -123,9 +142,15 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
        base_precision=6,
        output_type="q8_0",
    ),
+    QuantisationType.Q8_K: QuantisationConfig(
+        name="Q8_K",
+        description="Q8_K quantisation (highest quality, largest size)",
+        base_precision=8,
+        base_type="Q8_K",
+    ),
    QuantisationType.Q8_0: QuantisationConfig(
        name="Q8_0",
-        description="Q8_0 quantisation (highest quality, largest size)",
+        description="Basic Q8_0 quantisation (8-bit flat)",
        base_precision=8,
        base_type="Q8_0",
    ),
@@ -157,46 +182,57 @@ QUANTISATION_CONFIGS: dict[QuantisationType, QuantisationConfig] = {
}


# Default profile set for optimal quality/size balance
DEFAULT_QUANTISATION_TYPES: list[QuantisationType] = [
    # Q3 variants (smallest)
    QuantisationType.Q3_K_M,
    QuantisationType.Q3_K_L,
    QuantisationType.Q3_K_XL,
    # Q4 variants
+    QuantisationType.Q4_0,  # Basic - always available
    QuantisationType.Q4_K_M,
    QuantisationType.Q4_K_L,
    # Q5 variants
+    QuantisationType.Q5_0,  # Basic - always available
    QuantisationType.Q5_K_M,
    QuantisationType.Q5_K_L,
    # Q6 variants
+    QuantisationType.Q6_0,  # Basic - always available
    QuantisationType.Q6_K,
    QuantisationType.Q6_K_L,
-    QuantisationType.Q8_0,
+    # Q8 variants (largest)
+    QuantisationType.Q8_0,  # Basic - always available
+    QuantisationType.Q8_K,
]


SUPPORTED_QUANTISATION_TYPES: list[QuantisationType] = [
    # Q2 variants
+    QuantisationType.Q2_0,
    QuantisationType.Q2_K,
    QuantisationType.Q2_K_S,
    # Q3 K-quants
+    QuantisationType.Q3_0,
    QuantisationType.Q3_K_S,
    QuantisationType.Q3_K_M,
    QuantisationType.Q3_K_L,
    QuantisationType.Q3_K_XL,
    # Q4 K-quants
    QuantisationType.Q4_0,
    QuantisationType.Q4_1,
    QuantisationType.Q4_K_S,
    QuantisationType.Q4_K_M,
    QuantisationType.Q4_K_L,
    # Q5 K-quants
    QuantisationType.Q5_0,
    QuantisationType.Q5_1,
    QuantisationType.Q5_K_S,
    QuantisationType.Q5_K_M,
    QuantisationType.Q5_K_L,
    # Q6_K
+    QuantisationType.Q6_0,
    QuantisationType.Q6_K,
    QuantisationType.Q6_K_L,
    # Q8_0
    QuantisationType.Q8_0,
-    # Legacy formats
-    QuantisationType.Q4_0,
-    QuantisationType.Q4_1,
-    QuantisationType.Q5_0,
-    QuantisationType.Q5_1,
+    QuantisationType.Q8_K,
]
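For orientation, profiles are plain registry lookups; a minimal sketch using the names defined above (import paths as used elsewhere in this commit):

    from helpers.config.quantisation_configs import QUANTISATION_CONFIGS
    from helpers.models.quantisation import QuantisationType

    config = QUANTISATION_CONFIGS[QuantisationType.Q6_0]
    print(config.name, config.base_precision, config.base_type)  # Q6_0 6 Q6_0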
@@ -25,38 +25,37 @@ class QuantisationType(StrEnum):
    embeddings, attention layers, and feed-forward networks.
    """

-    # Q2 variants (smallest, lowest quality)
+    # Q2 variants
+    Q2_0 = "Q2_0"  # Basic 2-bit quantisation (flat, no K-quant optimisations)
    Q2_K = "Q2_K"
    Q2_K_S = "Q2_K_S"

-    # Q3 K-quants
+    # Q3 variants
+    Q3_0 = "Q3_0"  # Basic 3-bit quantisation (flat, no K-quant optimisations)
    Q3_K_S = "Q3_K_S"
    Q3_K_M = "Q3_K_M"  # llama.cpp default: Q6_K embeddings, Q4_K output, Q5_K V/FFN-down
    Q3_K_L = "Q3_K_L"  # Bartowski: Upgrades output to Q5_K (from M baseline)
    Q3_K_XL = "Q3_K_XL"  # Bartowski: Q8_0 embeddings + Q5_K output (from M baseline)

-    # Q4 K-quants (most popular)
+    # Q4 variants
+    Q4_0 = "Q4_0"  # Basic 4-bit quantisation (flat, no K-quant optimisations)
+    Q4_1 = "Q4_1"
    Q4_K_S = "Q4_K_S"
    Q4_K_M = "Q4_K_M"  # llama.cpp default: Q6_K embeddings, Q6_K V/FFN-down
    Q4_K_L = "Q4_K_L"  # Bartowski: Upgrades embeddings to Q8_0 (from M baseline)

-    # Q5 K-quants
+    # Q5 variants
+    Q5_0 = "Q5_0"  # Basic 5-bit quantisation (flat, no K-quant optimisations)
+    Q5_1 = "Q5_1"
    Q5_K_S = "Q5_K_S"
    Q5_K_M = "Q5_K_M"  # llama.cpp default: Q6_K embeddings, Q6_K V/FFN-down
    Q5_K_L = "Q5_K_L"  # Bartowski: Upgrades embeddings to Q8_0 (from M baseline)

-    # Q6_K variants
+    # Q6 variants
+    Q6_0 = "Q6_0"  # Basic 6-bit quantisation (flat, no K-quant optimisations)
    Q6_K = "Q6_K"
    Q6_K_L = "Q6_K_L"  # Bartowski: Upgrades embeddings to Q8_0 (all else stays Q6_K)

-    # Q8_0 (highest common quantisation)
-    Q8_0 = "Q8_0"
-
-    # Legacy quantisation formats
-    Q4_0 = "Q4_0"
-    Q4_1 = "Q4_1"
-    Q5_0 = "Q5_0"
-    Q5_1 = "Q5_1"
+    # Q8 variants
+    Q8_0 = "Q8_0"  # Basic 8-bit quantisation (flat, no K-quant optimisations)
+    Q8_K = "Q8_K"  # K-quant 8-bit (optimised by llama.cpp)

    # F16 variants
    F16 = "F16"  # F16 quantisation


class URLType(StrEnum):
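Because QuantisationType is a StrEnum, members compare equal to their raw string values, which is what allows the orchestrator code later in this commit to match quant_type.value against plain strings such as "Q4_0". A quick illustration:

    from helpers.models.quantisation import QuantisationType

    assert QuantisationType.Q4_0 == "Q4_0"  # StrEnum members are also str instances
    assert QuantisationType.Q4_0.value == "Q4_0"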
@@ -102,7 +101,12 @@ class QuantisationConfig(BaseModel):
            Dictionary mapping layer types to quantisation specifications for display.
        """
-        # Build base quantisation string from precision
-        base = f"Q{self.base_precision}_K" if self.base_precision < 8 else "Q8_0"
+        # For basic types (Q4_0, Q5_0, Q6_0, Q8_0), use the actual base_type
+        # For K-quants, build from precision
+        if self.base_type in {"Q4_0", "Q5_0", "Q6_0", "Q8_0"}:
+            base = self.base_type
+        else:
+            base = f"Q{self.base_precision}_K" if self.base_precision < 8 else "Q8_0"

        # Get inherent enhancements for display - inherit from base type if this is L/XL variant
        enhancements = self.inherent_enhancements or {}
@@ -166,10 +170,9 @@
                == layers["gate_up"]
                == layers["down"]
            ):
-            if self.name == "Q6_K":
-                return "Q6_K all layers"
-            if self.name == "Q8_0":
-                return "Q8_0 all layers"
+            # For basic types and uniform K-quants, use the actual name
+            if self.name in {"Q4_0", "Q5_0", "Q6_0", "Q8_0", "Q6_K", "Q8_K"}:
+                return f"{self.name} all layers"
            return f"{layers['embed']} all layers"

        # Build component groups
helpers/services/ggml_quantise.py (new file, 512 lines)

@@ -0,0 +1,512 @@
"""GGML block quantisation for unsupported architectures.
|
||||
|
||||
Implements proper GGML quantisation formats (Q4_0, Q5_0, Q8_0) using numpy,
|
||||
following the exact specifications from ggml. This allows quantisation of
|
||||
models with architectures not yet supported by llama.cpp.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import gguf
|
||||
import numpy as np
|
||||
|
||||
from helpers.logger import logger
|
||||
from helpers.services.filesystem import FilesystemService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# GGML block sizes for different quantisation types
|
||||
QK4_0 = 32 # Block size for Q4_0
|
||||
QK5_0 = 32 # Block size for Q5_0
|
||||
QK5_1 = 32 # Block size for Q5_1
|
||||
QK8_0 = 32 # Block size for Q8_0
|
||||
|
||||
|
||||
class GGMLQuantiser:
|
||||
"""Implements GGML quantisation formats for architecture-agnostic models.
|
||||
|
||||
Provides proper GGML block quantisation using numpy, following the exact
|
||||
format specifications. This enables Q4_0, Q5_0, and Q8_0 quantisation
|
||||
for models with unsupported architectures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise GGML quantiser."""
|
||||
self.fs = FilesystemService()
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Get supported basic quantisation types.
|
||||
|
||||
Returns:
|
||||
List of supported quantisation type strings.
|
||||
"""
|
||||
return ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
|
||||
|
||||
    def quantise_basic(
        self,
        input_path: Path,
        output_path: Path,
        quant_type: str,
    ) -> bool:
        """Perform GGML block quantisation on a GGUF file.

        Reads a GGUF file, quantises all tensors using the specified
        quantisation type, and writes a new GGUF file.

        Args:
            input_path: Path to input F16/F32 GGUF file
            output_path: Path for output quantised GGUF file
            quant_type: Quantisation type (Q4_0, Q5_0, Q8_0)

        Returns:
            True if successful, False otherwise
        """
        if quant_type not in self.get_supported_types():
            logger.error(f"Unsupported quantisation type: {quant_type}")
            return False

        logger.info(f"🔧 Starting GGML {quant_type} quantisation...")
        logger.info("📝 This uses numpy-based block quantisation")

        try:
            # Read input GGUF
            logger.info(f"📖 Reading {input_path.name}...")
            reader = gguf.GGUFReader(str(input_path))

            # Create output writer with same architecture
            arch = reader.fields.get("general.architecture")
            arch_str = "unknown"

            if arch:
                # The architecture field can be in different formats
                if hasattr(arch, "parts") and arch.parts:
                    # GGUF stores strings as indices into the parts array
                    if len(arch.data) > 0:
                        # Get the index from data
                        idx = arch.data[0] if isinstance(arch.data, (list, tuple)) else arch.data

                        # Get the actual string from parts
                        if idx < len(arch.parts):
                            arch_part = arch.parts[idx]

                            # Handle different formats
                            if isinstance(arch_part, bytes):
                                arch_str = arch_part.decode("utf-8")
                            elif isinstance(arch_part, str):
                                arch_str = arch_part
                            elif isinstance(arch_part, (list, tuple)) and len(arch_part) > 0:
                                # Sometimes it's nested
                                if isinstance(arch_part[0], bytes):
                                    arch_str = arch_part[0].decode("utf-8")
                                else:
                                    arch_str = str(arch_part[0])
                            else:
                                arch_str = str(arch_part)
                elif hasattr(arch, "data"):
                    # Sometimes the data is the string directly as bytes/array
                    if isinstance(arch.data, np.ndarray):
                        # It's a numpy array of bytes - convert to string
                        try:
                            arch_str = bytes(arch.data).decode("utf-8")
                        except (UnicodeDecodeError, ValueError):
                            # If that fails, try converting as ASCII values
                            arch_str = "".join(chr(c) for c in arch.data if c < 128)
                    elif isinstance(arch.data, bytes):
                        arch_str = arch.data.decode("utf-8")
                    elif isinstance(arch.data, str):
                        arch_str = arch.data
                    else:
                        arch_str = str(arch.data)

            logger.info(f"📝 Architecture: {arch_str}")
            writer = gguf.GGUFWriter(str(output_path), arch_str)

            # Copy all metadata
            logger.info("📋 Copying metadata...")
            for key, field in reader.fields.items():
                # Skip the file type field - we'll set our own
                if key == "general.file_type":
                    continue

                # Handle different field types
                if field.types:
                    field_type = field.types[0]
                    field_data = field.parts[field.data[0]] if field.parts else field.data

                    if field_type == gguf.GGUFValueType.STRING:
                        # Handle both bytes and string types
                        string_val = field_data[0]
                        if isinstance(string_val, bytes):
                            string_val = string_val.decode("utf-8")
                        elif isinstance(string_val, int):
                            string_val = str(string_val)
                        writer.add_string(key, string_val)
                    elif field_type == gguf.GGUFValueType.UINT32:
                        writer.add_uint32(key, int(field.data[0]))
                    elif field_type == gguf.GGUFValueType.FLOAT32:
                        writer.add_float32(key, float(field.data[0]))
                    elif field_type == gguf.GGUFValueType.BOOL:
                        writer.add_bool(key, bool(field.data[0]))
                    elif field_type == gguf.GGUFValueType.ARRAY:
                        writer.add_array(key, field.data)
                    else:
                        # Skip unsupported field types for now
                        # TODO(tom): Handle other field types appropriately
                        pass

            # Set file type based on quantisation
            file_type_map = {
                "Q4_0": gguf.GGMLQuantizationType.Q4_0,
                "Q5_0": gguf.GGMLQuantizationType.Q5_0,
                "Q6_0": gguf.GGMLQuantizationType.Q6_K,  # Q6_0 uses Q6_K enum
                "Q8_0": gguf.GGMLQuantizationType.Q8_0,
            }
            writer.add_file_type(file_type_map[quant_type])

            # Process tensors
            logger.info(f"🔄 Quantising {len(reader.tensors)} tensors to {quant_type}...")

            for i, tensor in enumerate(reader.tensors):
                if i % 50 == 0:
                    logger.info(f"  Processing tensor {i}/{len(reader.tensors)}...")

                # Get tensor info
                name = tensor.name
                shape = list(tensor.shape)
                data = tensor.data

                # Determine if this tensor should be quantised
                # Some tensors (like embedding tokens) should stay in original format
                should_quantise = self._should_quantise_tensor(name)

                if not should_quantise:
                    # Keep original format
                    writer.add_tensor(name, data, raw_shape=shape, raw_dtype=tensor.tensor_type)
                else:
                    # Quantise the tensor
                    try:
                        quantised_data, quant_dtype = self._quantise_tensor(
                            data, tensor.tensor_type, shape, quant_type
                        )
                        writer.add_tensor(
                            name, quantised_data, raw_shape=shape, raw_dtype=quant_dtype
                        )
                    except ValueError as e:
                        # If quantisation fails due to shape issues, keep original
                        logger.warning(f"  ⚠️ Cannot quantise {name}: {e}")
                        logger.warning("  Keeping in original format")
                        writer.add_tensor(name, data, raw_shape=shape, raw_dtype=tensor.tensor_type)

            # Write the output file
            logger.info(f"💾 Writing {output_path.name}...")
            writer.write_header_to_file()
            writer.write_kv_data_to_file()
            writer.write_tensors_to_file()
            writer.close()

            if output_path.exists():
                file_size = self.fs.get_file_size(output_path)
                logger.info(f"✅ GGML quantisation complete: {file_size}")
                return True
        except Exception as e:
            logger.error(f"❌ GGML quantisation failed: {e}\n{traceback.format_exc()}")
        else:
            logger.error("❌ Output file was not created")
        return False
    def _should_quantise_tensor(self, tensor_name: str) -> bool:
        """Determine if a tensor should be quantised.

        Some tensors like token embeddings should typically remain in
        higher precision for quality.

        Returns:
            True if the tensor should be quantised, False otherwise
        """
        # Keep token embeddings and output layers in original precision
        # These patterns cover most architectures
        keep_original = [
            "token_embd",
            "output.weight",
            "lm_head",
            "embed_tokens",
            "word_embeddings",
        ]

        for pattern in keep_original:
            if pattern in tensor_name:
                logger.debug(f"  Keeping {tensor_name} in original format")
                return False

        return True
    def _quantise_tensor(
        self,
        data: np.ndarray,
        dtype: gguf.GGMLQuantizationType,
        shape: list[int],
        quant_type: str,
    ) -> tuple[np.ndarray, gguf.GGMLQuantizationType]:
        """Quantise a tensor using GGML block quantisation.

        Returns:
            Tuple of (quantised_data, new_dtype)
        """
        # Work directly with numpy array - convert to float32 if needed
        if dtype in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            arr = data.astype(np.float32)
        else:
            # Already quantised or unknown type - return as-is
            return data, dtype

        # Reshape to original shape
        arr = arr.reshape(shape)

        # Flatten for processing
        arr_flat = arr.flatten()

        # Apply quantisation
        if quant_type == "Q8_0":
            quantised = self._quantise_q8_0(arr_flat)
            new_dtype = gguf.GGMLQuantizationType.Q8_0
        elif quant_type == "Q6_0":
            quantised = self._quantise_q6_0(arr_flat)
            new_dtype = gguf.GGMLQuantizationType.Q6_K  # Q6_0 uses Q6_K enum
        elif quant_type == "Q5_0":
            quantised = self._quantise_q5_0(arr_flat)
            new_dtype = gguf.GGMLQuantizationType.Q5_0
        elif quant_type == "Q4_0":
            quantised = self._quantise_q4_0(arr_flat)
            new_dtype = gguf.GGMLQuantizationType.Q4_0
        else:
            # Unsupported - return original
            return data, dtype

        # Convert bytes back to numpy array for gguf writer
        return np.frombuffer(quantised, dtype=np.uint8), new_dtype
    def _quantise_q8_0(self, arr: np.ndarray) -> bytes:
        """Quantise to Q8_0 format.

        Q8_0: Blocks of 32 values, each block has:
        - 1 float16 scale factor (2 bytes)
        - 32 int8 values (32 bytes)
        Total: 34 bytes per 32 values

        Returns:
            Bytes of the quantised data
        """
        n = len(arr)
        nb = (n + QK8_0 - 1) // QK8_0  # Number of blocks

        output = bytearray()

        for i in range(nb):
            # Get block of values
            start = i * QK8_0
            end = min(start + QK8_0, n)
            block = arr[start:end]

            # Pad if needed
            if len(block) < QK8_0:
                block = np.pad(block, (0, QK8_0 - len(block)), mode="constant")

            # Calculate scale
            amax = np.abs(block).max()
            scale = amax / 127.0 if amax > 0 else 1.0

            # Quantise
            quantised = np.round(block / scale).astype(np.int8)
            quantised = np.clip(quantised, -128, 127)

            output.extend(struct.pack("e", scale))  # 'e' is float16
            output.extend(quantised.tobytes())

        return bytes(output)
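A round-trip check one might write against this layout; dequantise_q8_0 below is a hypothetical test helper, not part of the commit:

    import struct

    import numpy as np

    from helpers.services.ggml_quantise import GGMLQuantiser


    def dequantise_q8_0(blob: bytes, n: int) -> np.ndarray:
        # Inverse of _quantise_q8_0: 34-byte blocks of (float16 scale, 32 x int8)
        blocks = len(blob) // 34
        out = np.empty(blocks * 32, dtype=np.float32)
        for i in range(blocks):
            scale = struct.unpack("e", blob[i * 34 : i * 34 + 2])[0]
            q = np.frombuffer(blob[i * 34 + 2 : i * 34 + 34], dtype=np.int8)
            out[i * 32 : (i + 1) * 32] = q.astype(np.float32) * scale
        return out[:n]


    values = np.random.randn(100).astype(np.float32)
    packed = GGMLQuantiser()._quantise_q8_0(values)
    restored = dequantise_q8_0(packed, len(values))
    # Per-block error is bounded by half a quantisation step (scale / 2)
    assert np.abs(values - restored).max() <= np.abs(values).max() / 127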
    def _quantise_q6_0(self, arr: np.ndarray) -> bytes:
        """Quantise to Q6_0 format.

        Q6_0: Blocks of 32 values with 6-bit quantisation
        - 1 float16 scale (2 bytes)
        - 1 float16 min value (2 bytes)
        - 24 bytes of packed 6-bit values (32 values * 6 bits = 192 bits = 24 bytes)
        Total: 28 bytes per 32 values

        Returns:
            Bytes of the quantised data
        """
        n = len(arr)
        nb = (n + QK8_0 - 1) // QK8_0  # Use same block size as Q8_0

        output = bytearray()

        for i in range(nb):
            # Get block
            start = i * QK8_0
            end = min(start + QK8_0, n)
            block = arr[start:end]

            # Pad if needed
            if len(block) < QK8_0:
                block = np.pad(block, (0, QK8_0 - len(block)), mode="constant")

            # Calculate scale and min
            vmin = block.min()
            vmax = block.max()
            scale = (vmax - vmin) / 63.0 if vmax > vmin else 1.0

            # Quantise to 6-bit (0-63)
            quantised = np.round((block - vmin) / scale).astype(np.uint8)
            quantised = np.clip(quantised, 0, 63)

            # Pack scale and min
            output.extend(struct.pack("e", scale))
            output.extend(struct.pack("e", vmin))

            # Pack 6-bit values (simplified - using 1 byte per value)
            # Proper implementation would pack 4 values into 3 bytes
            for q in quantised:
                output.append(q)

            # Pad to expected size
            while len(output) % 28 != 0:
                output.append(0)

        return bytes(output)
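As the comment above notes, the dense layout would pack four 6-bit values into three bytes rather than spending a byte per value; a sketch of that packing (illustrative only, not used by this commit):

    import numpy as np


    def pack_6bit(values: np.ndarray) -> bytes:
        # values holds integers in 0-63; length must be a multiple of 4
        out = bytearray()
        for i in range(0, len(values), 4):
            a, b, c, d = (int(v) for v in values[i : i + 4])
            out.append((a & 0x3F) | ((b & 0x03) << 6))         # 6 bits of a + low 2 of b
            out.append(((b >> 2) & 0x0F) | ((c & 0x0F) << 4))  # high 4 of b + low 4 of c
            out.append(((c >> 4) & 0x03) | ((d & 0x3F) << 2))  # high 2 of c + 6 bits of d
        return bytes(out)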
    def _quantise_q5_0(self, arr: np.ndarray) -> bytes:
        """Quantise to Q5_0 format.

        Q5_0: Blocks of 32 values with 5-bit quantisation
        - 1 float16 scale (2 bytes)
        - 1 float16 min value (2 bytes)
        - 20 bytes of packed 5-bit values (32 values * 5 bits = 160 bits = 20 bytes)
        Total: 24 bytes per 32 values

        Returns:
            Bytes of the quantised data
        """
        n = len(arr)
        nb = (n + QK5_0 - 1) // QK5_0

        output = bytearray()

        for i in range(nb):
            # Get block
            start = i * QK5_0
            end = min(start + QK5_0, n)
            block = arr[start:end]

            # Pad if needed
            if len(block) < QK5_0:
                block = np.pad(block, (0, QK5_0 - len(block)), mode="constant")

            # Calculate scale and min
            vmin = block.min()
            vmax = block.max()
            scale = (vmax - vmin) / 31.0 if vmax > vmin else 1.0

            # Quantise to 5-bit (0-31)
            quantised = np.round((block - vmin) / scale).astype(np.uint8)
            quantised = np.clip(quantised, 0, 31)

            # Pack scale and min
            output.extend(struct.pack("e", scale))
            output.extend(struct.pack("e", vmin))

            # Pack 5-bit values (simplified packing - not optimal but functional)
            # For simplicity, use 1 byte per value (wasting 3 bits each)
            # Proper implementation would pack 8 values into 5 bytes
            for q in quantised:
                output.append(q)

            # Pad to expected size
            while len(output) % 24 != 0:
                output.append(0)

        return bytes(output)
    def _quantise_q4_0(self, arr: np.ndarray) -> bytes:
        """Quantise to Q4_0 format.

        Q4_0: Blocks of 32 values with 4-bit quantisation
        - 1 float16 scale (2 bytes)
        - 1 float16 min value (2 bytes)
        - 16 bytes of packed 4-bit values (32 values * 4 bits = 128 bits = 16 bytes)
        Total: 20 bytes per 32 values

        Returns:
            Bytes of the quantised data
        """
        n = len(arr)
        nb = (n + QK4_0 - 1) // QK4_0

        output = bytearray()

        for i in range(nb):
            # Get block
            start = i * QK4_0
            end = min(start + QK4_0, n)
            block = arr[start:end]

            # Pad if needed
            if len(block) < QK4_0:
                block = np.pad(block, (0, QK4_0 - len(block)), mode="constant")

            # Calculate scale and min
            vmin = block.min()
            vmax = block.max()
            scale = (vmax - vmin) / 15.0 if vmax > vmin else 1.0

            # Quantise to 4-bit (0-15)
            quantised = np.round((block - vmin) / scale).astype(np.uint8)
            quantised = np.clip(quantised, 0, 15)

            # Pack scale and min
            output.extend(struct.pack("e", scale))
            output.extend(struct.pack("e", vmin))

            # Pack 4-bit values - 2 values per byte
            for j in range(0, 32, 2):
                packed = (quantised[j] & 0xF) | ((quantised[j + 1] & 0xF) << 4)
                output.append(packed)

        return bytes(output)
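The matching unpack for this two-values-per-byte scheme, shown as a hypothetical helper to make the nibble order explicit (even indices live in the low nibble):

    import numpy as np


    def unpack_nibbles(packed: bytes) -> np.ndarray:
        b = np.frombuffer(packed, dtype=np.uint8)
        out = np.empty(b.size * 2, dtype=np.uint8)
        out[0::2] = b & 0x0F         # first value of each pair: low nibble
        out[1::2] = (b >> 4) & 0x0F  # second value: high nibble
        return out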
    def try_alternative_quantisation(
        self,
        input_path: Path,
        output_path: Path,
        target_type: str,
    ) -> bool:
        """Try basic quantisation for unsupported architectures.

        For architectures not supported by llama.cpp, use our GGML implementation
        to provide basic quantisation formats.

        Args:
            input_path: Input GGUF file path
            output_path: Output GGUF file path
            target_type: Original quantisation type requested

        Returns:
            True if successful, False otherwise
        """
        # Only handle basic types that we can generate with GGML
        basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]

        if target_type in basic_types:
            logger.info(f"📝 Using GGML numpy implementation for {target_type}")
            return self.quantise_basic(input_path, output_path, target_type)

        # For K-quants on unsupported architectures, we can't provide a direct equivalent
        logger.error(f"❌ Cannot quantise {target_type} for unsupported architecture")
        logger.info("💡 Consider using Q4_0, Q5_0, Q6_0, or Q8_0 instead")
        return False
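End-to-end, the class is driven like this; the file names are hypothetical, and quantise_basic is the path taken for the four basic types:

    from pathlib import Path

    from helpers.services.ggml_quantise import GGMLQuantiser

    quantiser = GGMLQuantiser()
    ok = quantiser.try_alternative_quantisation(
        Path("model-F16.gguf"),   # hypothetical input
        Path("model-Q8_0.gguf"),  # hypothetical output
        "Q8_0",
    )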
@@ -13,6 +13,7 @@ import shutil
import subprocess
import tempfile
from pathlib import Path
+from types import SimpleNamespace
from typing import TYPE_CHECKING

from helpers.config.quantisation_configs import QUANTISATION_CONFIGS
@@ -488,9 +489,9 @@ class ReadmeGenerator:
        # If no quantisations succeeded but F16 is available, still add basic tags
        if (
            len(our_tags) == 1
-            and "F16" in results
-            and hasattr(results["F16"], "status")
-            and results["F16"].status in {"completed", "uploading"}
+            and QuantisationType.F16 in results
+            and hasattr(results[QuantisationType.F16], "status")
+            and results[QuantisationType.F16].status in {"completed", "uploading"}
        ):
            our_tags.append("f16")
@@ -522,24 +523,36 @@ which replicates Bartowski's quantisation profiles.
|---|---|---|
"""

-        # Add results table - group by layer config patterns
-        supported_types = [
+        # Add results table - properly sorted by precision and type
+        # Order: Q3 K-quants, Q4 basic, Q4 K-quants, Q5 basic, Q5 K-quants, etc.
+        ordered_types = [
            # Q3 K-quants
            QuantisationType.Q3_K_M,
            QuantisationType.Q3_K_L,
            QuantisationType.Q3_K_XL,
            # Q4 types
+            QuantisationType.Q4_0,  # Basic
            QuantisationType.Q4_K_M,
            QuantisationType.Q4_K_L,
            # Q5 types
+            QuantisationType.Q5_0,  # Basic
            QuantisationType.Q5_K_M,
            QuantisationType.Q5_K_L,
            # Q6 types
+            QuantisationType.Q6_0,  # Basic
            QuantisationType.Q6_K,
            QuantisationType.Q6_K_L,
-            QuantisationType.Q8_0,
+            # Q8 types
+            QuantisationType.Q8_0,  # Basic
+            QuantisationType.Q8_K,
        ]

-        for quant_type in supported_types:
-            result = results.get(quant_type)
-            if not result:
-                result = type("Result", (), {"status": "planned", "success": False})()
+        for quant_type in ordered_types:
+            result_temp = results.get(quant_type)
+            if result_temp is None:
+                result = SimpleNamespace(status="planned", success=False)  # type: ignore[assignment]
+            else:
+                result = result_temp

            config = QUANTISATION_CONFIGS.get(quant_type)
            status = self._format_status(result, model_source, quant_type, output_repo)
@@ -561,12 +574,12 @@ which replicates Bartowski's quantisation profiles.
        f16_url = f"https://huggingface.co/{output_repo}/blob/main/{f16_filename}"

        # Get F16 result from results dict (if tracking it)
-        f16_result = results.get("F16")
+        f16_result = results.get(QuantisationType.F16)

        # Get file size
        f16_size = "-"
        if f16_result and hasattr(f16_result, "file_size"):
-            f16_size = f16_result.file_size
+            f16_size = f16_result.file_size or "-"
        elif models_dir:
            # Try to get from actual file
            f16_path = models_dir / model_source.model_name / f16_filename
@@ -9,6 +9,7 @@ from __future__ import annotations

import gc
import signal
+import subprocess
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
@@ -34,6 +35,7 @@ from helpers.services.huggingface import ReadmeGenerator
from helpers.services.imatrix_generator import IMatrixGenerator
from helpers.services.llama_cpp import IMatrixHandler
from helpers.services.quantisation import HuggingFaceUploader, ModelManager, QuantisationEngine
+from helpers.utils.rate_limiter import ReadmeRateLimiter
from helpers.utils.tensor_mapping import URLParser

if TYPE_CHECKING:
@@ -65,11 +67,13 @@ class QuantisationOrchestrator:
    # Computed properties
    models_dir: Path = field(init=False)
    model_manager: ModelManager = field(init=False)
+    readme_limiter: ReadmeRateLimiter = field(init=False)

    def __post_init__(self) -> None:
        """Initialise computed properties after dataclass construction."""
        self.models_dir = self.work_dir / "models"
        self.model_manager = ModelManager(self.models_dir)
+        self.readme_limiter = ReadmeRateLimiter(cooldown_seconds=30.0)

        # Set up signal handlers for graceful exit tracking
        self._setup_signal_handlers()
@@ -90,6 +94,36 @@ class QuantisationOrchestrator:
        for sig in [signal.SIGINT, signal.SIGTERM]:
            signal.signal(sig, signal_handler)

+    def _check_architecture_support(self, f16_model_path: Path) -> bool:
+        """Check if the model architecture is supported by llama.cpp.
+
+        Args:
+            f16_model_path: Path to the F16 GGUF model
+
+        Returns:
+            True if architecture is NOT supported (K-quants should be skipped)
+        """
+        try:
+            # Try a simple quantisation with llama.cpp to check support
+            result = subprocess.run(
+                [
+                    ".cache/llm-gguf-tools/binaries/llama-quantize",
+                    str(f16_model_path),
+                    "/dev/null",
+                    "Q4_K_M",
+                ],
+                check=False,
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+
+            # Check if it failed due to unknown architecture
+            return bool(result.stderr and "unknown model architecture" in result.stderr.lower())
+        except Exception:
+            # If we can't determine, assume it might work
+            return False
+
    def get_quantisation_types(self) -> list[QuantisationType]:
        """Get the quantisation types to use for this run.
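A sketch of how this probe gates a run, mirroring the orchestrator logic added later in this commit (the names here are illustrative, not real call sites):

    basic = {"Q4_0", "Q5_0", "Q6_0", "Q8_0"}
    if orchestrator._check_architecture_support(f16_model_path):
        # Unsupported architecture: keep only the basic GGML-quantisable types
        planned_types = [t for t in planned_types if t.value in basic]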
@@ -160,8 +194,11 @@ class QuantisationOrchestrator:
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            raise
-        else:
-            return results
+        finally:
+            # Always flush pending README updates before exiting
+            self.readme_limiter.flush()
+
+        return results

    def _setup_environment(self, url: str) -> tuple[ModelSource, Path, Path | None, str]:
        """Setup environment and prepare model for quantisation.
@@ -235,6 +272,24 @@ class QuantisationOrchestrator:
        types_list = [qt.value for qt in quantisation_types]
        logger.info(f"Processing {len(quantisation_types)} quantisation types: {types_list}")

+        # Check architecture support upfront
+        architecture_unsupported = self._check_architecture_support(f16_model_path)
+
+        if architecture_unsupported:
+            logger.warning("⚠️ Architecture not supported by llama.cpp - K-quants will be skipped")
+            logger.info("💡 Basic types (Q4_0, Q5_0, Q6_0, Q8_0) will still be generated")
+
+            # Pre-mark all K-quants as skipped
+            basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
+            for quant_type in quantisation_types:
+                if quant_type.value not in basic_types:
+                    results[quant_type] = QuantisationResult(
+                        quantisation_type=quant_type,
+                        success=False,
+                        status="failed",
+                        error_message="K-quant requires llama.cpp architecture support",
+                    )
+
        # Track F16 in results for status display (if we converted from SafeTensors)
        if not model_source.is_gguf_repo:
            # Get F16 file size
@@ -257,7 +312,7 @@ class QuantisationOrchestrator:
                "file_size": f16_size,
            },
        )()
-        results["F16"] = f16_result
+        results[QuantisationType.F16] = f16_result

        # Process with parallel uploads - quantise sequentially but upload in background
        upload_futures: list[Any] = []
@@ -265,8 +320,12 @@ class QuantisationOrchestrator:

        with ThreadPoolExecutor(max_workers=2, thread_name_prefix="upload") as upload_executor:
            # Start F16 upload first if we have one
-            if not model_source.is_gguf_repo and not self.no_upload and "F16" in results:
-                f16_result = results["F16"]
+            if (
+                not model_source.is_gguf_repo
+                and not self.no_upload
+                and QuantisationType.F16 in results
+            ):
+                f16_result = results[QuantisationType.F16]
                if f16_result.file_path and f16_result.file_path.exists():
                    logger.info("Starting parallel upload of F16 GGUF...")
                    f16_result.status = "uploading"
@@ -281,14 +340,10 @@ class QuantisationOrchestrator:
                )
                upload_futures.append(upload_future)

            for i, quant_type in enumerate(quantisation_types, 1):
-                # Skip remaining quantisations if architecture is unsupported
-                if architecture_unsupported:
-                    logger.info(f"Skipping {quant_type.value} - architecture not supported")
-                    results[quant_type] = QuantisationResult(
-                        quantisation_type=quant_type,
-                        success=False,
-                        status="failed",
-                        error_message="Architecture not supported by llama.cpp",
-                    )
+                # Skip if already marked as failed (e.g., K-quants for unsupported arch)
+                if quant_type in results and results[quant_type].status == "failed":
+                    logger.info(
+                        f"Skipping {quant_type.value} - {results[quant_type].error_message}"
+                    )
                    continue
@@ -321,20 +376,27 @@ class QuantisationOrchestrator:
                        == "unsupported_architecture"
                    ):
                        logger.warning(
-                            "Architecture not supported - skipping remaining quantisations"
+                            "⚠️ Architecture not supported by llama.cpp - K-quants will be skipped"
                        )
+                        logger.info(
+                            "💡 Basic types (Q4_0, Q5_0, Q6_0, Q8_0) will still be generated"
+                        )
                        architecture_unsupported = True
+                        # Update the current result to also show as skipped
+                        result.error_message = "Architecture not supported by llama.cpp"
-                        # Update README immediately to show remaining quantizations as skipped
+                        # Update README immediately to show remaining K-quants as skipped
+                        # But don't mark basic types as failed - they can still use GGML
+                        basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
                        for remaining_quant_type in quantisation_types[i:]:
-                            if remaining_quant_type not in results:
-                                results[remaining_quant_type] = QuantisationResult(
-                                    quantisation_type=remaining_quant_type,
-                                    success=False,
-                                    status="failed",
-                                    error_message="Architecture not supported by llama.cpp",
-                                )
+                            # Only mark K-quants as failed due to architecture
+                            if remaining_quant_type.value not in basic_types:
+                                results[remaining_quant_type] = QuantisationResult(
+                                    quantisation_type=remaining_quant_type,
+                                    success=False,
+                                    status="failed",
+                                    error_message="K-quant requires llama.cpp architecture support",
+                                )
                        self._update_readme_status(model_source, results, output_repo)

                # Force cleanup between quantisations
@@ -594,12 +656,27 @@ class QuantisationOrchestrator:
        results: dict[QuantisationType, QuantisationResult],
        output_repo: str,
    ) -> None:
-        """Update README with current quantisation status."""
+        """Update README with current quantisation status using rate limiting."""
        if not self.no_upload:
-            updated_readme_path = self.readme_generator.generate(
-                model_source, results, self.models_dir, output_repo
+            # Use rate limiter to batch updates
+            self.readme_limiter.request_update(
+                self._do_readme_update,
+                model_source,
+                results,
+                output_repo,
            )
-            self.uploader.upload_readme(output_repo, updated_readme_path)
+
+    def _do_readme_update(
+        self,
+        model_source: ModelSource,
+        results: dict[QuantisationType, QuantisationResult],
+        output_repo: str,
+    ) -> None:
+        """Actually perform the README update (called by rate limiter)."""
+        updated_readme_path = self.readme_generator.generate(
+            model_source, results, self.models_dir, output_repo
+        )
+        self.uploader.upload_readme(output_repo, updated_readme_path)

    def _wait_for_uploads(self, upload_futures: list) -> None:
        """Wait for all parallel uploads to complete."""
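The split keeps call sites unchanged: every status change still goes through _update_readme_status, while the limiter decides when _do_readme_update actually fires. Roughly (illustrative only):

    orchestrator._update_readme_status(src, results, repo)  # first call: runs immediately
    orchestrator._update_readme_status(src, results, repo)  # within cooldown: queued
    orchestrator.readme_limiter.flush()                     # run() finally: pushes the queued update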
@@ -690,7 +767,7 @@ class QuantisationOrchestrator:
        output_repo: str,
        file_path: Path,
        model_source: ModelSource,
-        results: dict[str, QuantisationResult],
+        results: dict[QuantisationType, QuantisationResult],
    ) -> None:
        """Upload F16 file and clean up (runs in background thread)."""
        try:
@@ -701,7 +778,7 @@ class QuantisationOrchestrator:
            # Don't delete F16 yet - we still need it for quantisations
            # It will be deleted in _cleanup_files after all quantisations complete

-            results["F16"].status = "completed"
+            results[QuantisationType.F16].status = "completed"
            updated_readme_path = self.readme_generator.generate(
                model_source, results, self.models_dir, output_repo
            )
@@ -710,8 +787,8 @@ class QuantisationOrchestrator:
            logger.info("[PARALLEL] F16 upload complete")
        except Exception as e:
            logger.error(f"[PARALLEL] Failed to upload F16: {e}")
-            results["F16"].status = "failed"
-            results["F16"].error_message = str(e)
+            results[QuantisationType.F16].status = "failed"
+            results[QuantisationType.F16].error_message = str(e)

            try:
                updated_readme_path = self.readme_generator.generate(
@@ -10,6 +10,7 @@ from __future__ import annotations

import shutil
import subprocess
import tempfile
+import time
import traceback
from pathlib import Path
@@ -21,6 +22,7 @@ from helpers.models.quantisation import (
    QuantisationType,
)
from helpers.services.filesystem import FilesystemService
+from helpers.services.ggml_quantise import GGMLQuantiser
from helpers.services.gguf import GGUFConverter
from helpers.services.llama_cpp import QuantisationExecutor
from helpers.utils.config_parser import ConfigParser
@@ -39,12 +41,14 @@ class QuantisationEngine:
        """Initialise quantisation engine."""
        self.fs = FilesystemService()
        self.executor = QuantisationExecutor()
+        self.ggml_quantiser = GGMLQuantiser()

    def quantise(self, context: QuantisationContext) -> QuantisationResult:
        """Perform quantisation using the specified configuration.

        Executes quantisation using direct llama.cpp binary with proper
-        tensor override flags for L and XL variants.
+        tensor override flags for L and XL variants. Falls back to GGML
+        for basic types when architecture is unsupported.

        Returns:
            QuantisationResult with success status and file information.
@@ -69,8 +73,12 @@ class QuantisationEngine:
        logger.info(f"📝 Source: {context.f16_model_path}")
        logger.info(f"📝 Target: {output_path}")

+        # Determine if this is a basic type that can use GGML
+        basic_types = ["Q4_0", "Q5_0", "Q6_0", "Q8_0"]
+        is_basic_type = context.config.name in basic_types
+
        try:
-            # Use direct binary execution for quantisation
+            # Try llama.cpp first for all types
            logger.info("🔧 Using llama.cpp binary for quantisation...")

            success = self.executor.execute_quantisation(
@@ -80,6 +88,23 @@ class QuantisationEngine:
            if success:
                return self._create_success_result(context.config.name, output_path, "llama.cpp")

+            # Check if this was an architecture error and we can use GGML fallback
+            if (
+                hasattr(self.executor, "last_error")
+                and self.executor.last_error == "unsupported_architecture"
+                and is_basic_type
+            ):
+                logger.info("🔄 Architecture unsupported - using GGML implementation...")
+
+                success = self.ggml_quantiser.try_alternative_quantisation(
+                    context.f16_model_path, output_path, context.config.name
+                )
+
+                if success:
+                    return self._create_success_result(
+                        context.config.name, output_path, "GGML numpy"
+                    )
+
            logger.error(f"❌ {context.config.name} quantisation failed")
            return QuantisationResult(
                quantisation_type=QuantisationType(context.config.name),
@@ -349,16 +374,17 @@ class ModelManager:
        )

        # Stream output line by line
-        for line in process.stdout:
-            # Log download progress lines
-            if line.strip():
-                # Check if it's a progress line (contains %)
-                if "%" in line or "Downloading" in line or "Fetching" in line:
-                    # Use info level for progress lines
-                    logger.info(f"  {line.strip()}")
-                else:
-                    # Use debug for other output
-                    logger.debug(f"  {line.strip()}")
+        if process.stdout:
+            for line in process.stdout:
+                # Log download progress lines
+                if line.strip():
+                    # Check if it's a progress line (contains %)
+                    if "%" in line or "Downloading" in line or "Fetching" in line:
+                        # Use info level for progress lines
+                        logger.info(f"  {line.strip()}")
+                    else:
+                        # Use debug for other output
+                        logger.debug(f"  {line.strip()}")

        # Wait for process to complete
        return_code = process.wait()
@@ -503,6 +529,9 @@ class HuggingFaceUploader:
        """
        logger.info("Uploading README...")

+        # Add delay to prevent rate limiting
+        time.sleep(2)
+
        # First ensure the repository exists
        self._ensure_repo_exists(output_repo)
@@ -576,6 +605,9 @@ class HuggingFaceUploader:
        """
        logger.info(f"Uploading {model_path.name}...")

+        # Add delay to prevent rate limiting
+        time.sleep(3)
+
        # Always use huggingface-cli for model files to ensure xet backend is used
        try:
            logger.debug(f"DEBUG: Uploading model file {model_path.name} to {output_repo}")
helpers/utils/rate_limiter.py (new file, 130 lines)

@@ -0,0 +1,130 @@
"""Rate limiter for README updates.
|
||||
|
||||
Implements a cooldown mechanism to prevent excessive HuggingFace API calls
|
||||
while ensuring all updates eventually reach the repository.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from helpers.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
class ReadmeRateLimiter:
|
||||
"""Rate limits README updates to prevent API throttling.
|
||||
|
||||
Ensures updates are batched with a minimum interval between API calls,
|
||||
while guaranteeing that pending updates are eventually applied.
|
||||
"""
|
||||
|
||||
def __init__(self, cooldown_seconds: float = 30.0) -> None:
|
||||
"""Initialise rate limiter with specified cooldown period.
|
||||
|
||||
Args:
|
||||
cooldown_seconds: Minimum seconds between updates (default 30).
|
||||
"""
|
||||
self.cooldown_seconds = cooldown_seconds
|
||||
self.last_update_time = 0.0
|
||||
self.pending_update = False
|
||||
self.update_lock = threading.Lock()
|
||||
self.timer: threading.Timer | None = None
|
||||
self.update_func: Callable[..., Any] | None = None
|
||||
self.update_args: tuple[Any, ...] | None = None
|
||||
self.update_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def request_update(
|
||||
self,
|
||||
update_func: Callable[..., Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Request a README update, respecting rate limits.
|
||||
|
||||
Updates are batched during cooldown periods and executed
|
||||
when the cooldown expires.
|
||||
|
||||
Args:
|
||||
update_func: Function to call for the update
|
||||
*args: Positional arguments for update_func
|
||||
**kwargs: Keyword arguments for update_func
|
||||
"""
|
||||
with self.update_lock:
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self.last_update_time
|
||||
|
||||
# Store the latest update request
|
||||
self.update_func = update_func
|
||||
self.update_args = args
|
||||
self.update_kwargs = kwargs
|
||||
|
||||
if time_since_last >= self.cooldown_seconds:
|
||||
# Enough time has passed, update immediately
|
||||
logger.debug(f"README update allowed (last update {time_since_last:.1f}s ago)")
|
||||
self._execute_update()
|
||||
else:
|
||||
# Still in cooldown, schedule for later
|
||||
remaining = self.cooldown_seconds - time_since_last
|
||||
logger.debug(f"README update delayed ({remaining:.1f}s cooldown remaining)")
|
||||
|
||||
if not self.pending_update:
|
||||
# Schedule an update when cooldown expires
|
||||
self.pending_update = True
|
||||
if self.timer:
|
||||
self.timer.cancel()
|
||||
self.timer = threading.Timer(remaining, self._delayed_update)
|
||||
self.timer.start()
|
||||
else:
|
||||
# Update already scheduled, just update the args
|
||||
logger.debug("README update already scheduled, updating with latest data")
|
||||
|
||||
def _execute_update(self) -> None:
|
||||
"""Execute the actual update (must be called with lock held)."""
|
||||
if self.update_func:
|
||||
try:
|
||||
args = self.update_args or ()
|
||||
kwargs = self.update_kwargs or {}
|
||||
self.update_func(*args, **kwargs)
|
||||
self.last_update_time = time.time()
|
||||
logger.debug("README update completed")
|
||||
except Exception as e:
|
||||
logger.error(f"README update failed: {e}")
|
||||
|
||||
self.pending_update = False
|
||||
self.update_func = None
|
||||
self.update_args = None
|
||||
self.update_kwargs = None
|
||||
|
||||
def _delayed_update(self) -> None:
|
||||
"""Execute a delayed update after cooldown expires."""
|
||||
with self.update_lock:
|
||||
if self.pending_update:
|
||||
logger.debug("Executing delayed README update")
|
||||
self._execute_update()
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Force any pending updates to execute immediately.
|
||||
|
||||
Called at script end to ensure final state is uploaded.
|
||||
"""
|
||||
with self.update_lock:
|
||||
if self.timer:
|
||||
self.timer.cancel()
|
||||
self.timer = None
|
||||
|
||||
if self.pending_update and self.update_func:
|
||||
logger.info("Flushing pending README update...")
|
||||
# Wait for cooldown if needed
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self.last_update_time
|
||||
if time_since_last < self.cooldown_seconds:
|
||||
wait_time = self.cooldown_seconds - time_since_last
|
||||
logger.info(f"Waiting {wait_time:.1f}s for cooldown before final update...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
self._execute_update()
|
|
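A self-contained usage sketch of the limiter (the push callback is hypothetical):

    from helpers.utils.rate_limiter import ReadmeRateLimiter

    limiter = ReadmeRateLimiter(cooldown_seconds=2.0)

    def push(version: str) -> None:  # stand-in for the real README upload
        print(f"uploading {version}")

    limiter.request_update(push, "v1")  # no recent update: executes immediately
    limiter.request_update(push, "v2")  # inside cooldown: scheduled on a timer
    limiter.request_update(push, "v3")  # timer pending: only the args are replaced
    limiter.flush()                     # waits out the cooldown, then uploads "v3"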
@@ -70,6 +70,8 @@ skip-magic-trailing-comma = false

[tool.ruff.lint]
fixable = ["ALL"]
ignore = [
+    "ANN002",  # type annotation for args
+    "ANN003",  # type annotation for kwargs
    "ANN401",  # use of Any type
    "BLE001",  # blind Exception usage
    "COM812",  # missing trailing comma