132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
"""Quantisation profile management.
|
|
|
|
Manages selection and validation of quantisation types based on
|
|
user preferences, architecture support, and configuration.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from helpers.config.quantisation_configs import (
|
|
DEFAULT_QUANTISATION_TYPES,
|
|
SUPPORTED_QUANTISATION_TYPES,
|
|
)
|
|
from helpers.llama_cpp.architecture import ArchitectureDetector
|
|
from helpers.logger import logger
|
|
from helpers.models.quantisation import QuantisationType
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
|
|
class ProfileManager:
    """Manages quantisation profiles and type selection.

    Handles selection of quantisation types based on custom profiles,
    architecture support, and fallback to defaults.
    """

    # Quantisation type values that are still generated when the model's
    # architecture is not supported by llama.cpp; every other requested type
    # (e.g. K-quants) is reported as unsupported. Keep in sync with the list
    # quoted in the informational log message in ``filter_by_architecture``.
    _BASIC_TYPE_VALUES = frozenset({"Q4_0", "Q5_0", "Q6_0", "Q8_0"})

    @staticmethod
    def get_quantisation_types(
        custom_profiles: list[str] | None = None,
    ) -> list[QuantisationType]:
        """Get the quantisation types to use for this run.

        Determines which quantisation types should be processed based on
        custom profiles provided by the user, or falls back to default
        configurations if no custom profiles are specified.

        Args:
            custom_profiles: Optional user-supplied profile names. ``None`` or
                an empty list selects the default quantisation types.

        Returns:
            A new list of QuantisationType enums to process. It is a copy, so
            callers may modify it without affecting the module-level defaults.
        """
        if custom_profiles:
            return ProfileManager._parse_custom_profiles(custom_profiles)
        # Copy so callers cannot mutate the shared module-level default list.
        return list(DEFAULT_QUANTISATION_TYPES)

    @staticmethod
    def _normalise_profile_name(profile: str) -> str:
        """Normalise a user-supplied profile name before enum lookup.

        Strips surrounding whitespace (e.g. from ``"q4_k_m, q5_k_m"`` split on
        commas) and upper-cases the name, matching the upper-cased form the
        QuantisationType lookup in this class has always used.

        Returns:
            The stripped, upper-cased profile name.
        """
        return profile.strip().upper()

    @staticmethod
    def _parse_custom_profiles(profile_strings: list[str]) -> list[QuantisationType]:
        """Parse custom profile strings to QuantisationType enums.

        Validates and converts user-provided profile strings into proper
        QuantisationType enumerations, filtering out invalid or unsupported
        types whilst logging warnings for problematic entries. Repeated
        profiles are kept only once (first occurrence wins) so the same type
        is not processed twice.

        Args:
            profile_strings: User-supplied profile names; matching ignores
                case and surrounding whitespace.

        Returns:
            List of valid QuantisationType enums in first-seen order, or a
            copy of the default types if no entry was valid.
        """
        result: list[QuantisationType] = []
        for profile_str in profile_strings:
            try:
                profile = QuantisationType(
                    ProfileManager._normalise_profile_name(profile_str)
                )
            except ValueError:
                logger.warning(f"Invalid profile {profile_str}, skipping")
                continue

            if profile not in SUPPORTED_QUANTISATION_TYPES:
                logger.warning(f"Profile {profile_str} is not supported, skipping")
            elif profile not in result:
                # Skip duplicates such as "q4_k_m" and "Q4_K_M" given together.
                result.append(profile)

        # Fall back to (a copy of) the defaults if no valid profiles remain.
        return result or list(DEFAULT_QUANTISATION_TYPES)

    @staticmethod
    def filter_by_architecture(
        quantisation_types: list[QuantisationType],
        f16_model_path: Path,
    ) -> tuple[list[QuantisationType], list[QuantisationType]]:
        """Filter quantisation types based on architecture support.

        Analyses the F16 GGUF model to determine architecture compatibility
        and filters the requested quantisation types accordingly. Separates
        supported types from unsupported ones, especially filtering K-quants
        for architectures not supported by llama.cpp.

        Args:
            quantisation_types: The quantisation types requested for this run.
            f16_model_path: Path to the F16 GGUF model passed to
                ``ArchitectureDetector.check_architecture_support``.

        Returns:
            Tuple of (supported_types, unsupported_types). Both are new lists
            preserving the input order; unsupported_types is empty when the
            architecture is supported.
        """
        if ArchitectureDetector.check_architecture_support(f16_model_path):
            # All types supported; copy so the result does not alias the input.
            return list(quantisation_types), []

        # Architecture not supported - only the basic (non-K) types are kept.
        supported: list[QuantisationType] = []
        unsupported: list[QuantisationType] = []
        for quant_type in quantisation_types:
            if quant_type.value in ProfileManager._BASIC_TYPE_VALUES:
                supported.append(quant_type)
            else:
                unsupported.append(quant_type)

        if unsupported:
            logger.warning(
                "⚠️ Architecture not supported by llama.cpp - K-quants will be skipped"
            )
            logger.info("💡 Basic types (Q4_0, Q5_0, Q6_0, Q8_0) will still be generated")

        return supported, unsupported

    @staticmethod
    def validate_profiles(profiles: list[str]) -> list[str]:
        """Validate a list of profile strings.

        Checks each profile string to ensure it corresponds to a valid
        and supported quantisation type, logging warnings for invalid
        entries whilst returning only the valid profile strings. Matching
        ignores case and surrounding whitespace, consistent with
        ``get_quantisation_types``.

        Args:
            profiles: User-supplied profile names to check.

        Returns:
            List of valid profile strings, exactly as given and in input order.
        """
        valid: list[str] = []
        for profile in profiles:
            try:
                quant_type = QuantisationType(
                    ProfileManager._normalise_profile_name(profile)
                )
            except ValueError:
                logger.warning(f"Profile {profile} is not a valid quantisation type")
                continue

            if quant_type in SUPPORTED_QUANTISATION_TYPES:
                valid.append(profile)
            else:
                logger.warning(f"Profile {profile} exists but is not supported")

        return valid
|