"""Main quantisation orchestrator.
Provides the high-level orchestration of the complete quantisation
workflow, coordinating the various services and modules involved.
"""
from __future__ import annotations
import signal
import sys
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from helpers.filesystem import FileCleanup, WorkspaceManager
from helpers.huggingface import HuggingFaceUploader
from helpers.llama_cpp import IMatrixGenerator, IMatrixHandler
from helpers.logger import logger
from helpers.models.quantisation import QuantisationResult, QuantisationType
from helpers.quantisation.engine import QuantisationEngine
from helpers.quantisation.executor import QuantisationExecutor
from helpers.quantisation.model_manager import ModelManager
from helpers.quantisation.profile_manager import ProfileManager
from helpers.quantisation.progress import ProgressReporter
from helpers.readme import ReadmeGenerator
from helpers.utils.rate_limiter import ReadmeRateLimiter
from helpers.utils.tensor_mapping import URLParser
if TYPE_CHECKING:
from types import FrameType
from helpers.models.quantisation import ModelSource
@dataclass(slots=True)
class QuantisationOrchestrator:
"""Orchestrates the complete quantisation workflow.
Thin coordinator that delegates to specialised services for
each aspect of the quantisation workflow.
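
    A minimal usage sketch (illustrative only; the URL and the profile name
    below are placeholders, not values shipped with this module):

        orchestrator = QuantisationOrchestrator(
            work_dir=Path("/tmp/quantisation_work"),
            no_upload=True,  # local dry run, skip HuggingFace uploads
            custom_profiles=["Q4_K_M"],  # hypothetical profile name
        )
        results = orchestrator.quantise("https://huggingface.co/org/model")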
"""
work_dir: Path = field(default_factory=lambda: Path.cwd() / "quantisation_work")
use_imatrix: bool = True
no_upload: bool = False
custom_profiles: list[str] | None = None
# Service dependencies
url_parser: URLParser = field(default_factory=URLParser)
workspace_manager: WorkspaceManager = field(init=False)
model_manager: ModelManager = field(init=False)
profile_manager: ProfileManager = field(default_factory=ProfileManager)
progress_reporter: ProgressReporter = field(default_factory=ProgressReporter)
quantisation_executor: QuantisationExecutor = field(init=False)
imatrix_handler: IMatrixHandler = field(default_factory=IMatrixHandler)
imatrix_generator: IMatrixGenerator = field(default_factory=IMatrixGenerator)
readme_generator: ReadmeGenerator = field(default_factory=ReadmeGenerator)
uploader: HuggingFaceUploader = field(default_factory=HuggingFaceUploader)
file_cleanup: FileCleanup = field(default_factory=FileCleanup)
readme_limiter: ReadmeRateLimiter = field(init=False)
def __post_init__(self) -> None:
"""Initialise computed properties after dataclass construction."""
self.workspace_manager = WorkspaceManager(self.work_dir)
self.model_manager = ModelManager(self.workspace_manager.models_dir)
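        # Rate-limit README updates; any pending update is flushed in
        # quantise()'s finally block.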
self.readme_limiter = ReadmeRateLimiter(cooldown_seconds=30.0)
# Create executor with dependencies
self.quantisation_executor = QuantisationExecutor(
quantisation_engine=QuantisationEngine(),
uploader=self.uploader,
readme_generator=self.readme_generator,
file_cleanup=self.file_cleanup,
no_upload=self.no_upload,
)
# Set up signal handlers
self._setup_signal_handlers()
def _setup_signal_handlers(self) -> None:
"""Set up signal handlers to catch unexpected exits."""
def signal_handler(signum: int, frame: FrameType | None) -> None:
logger.error(f"❌ Received signal {signum} ({signal.Signals(signum).name})")
logger.error("Stack trace at signal:")
if frame:
for line in traceback.format_stack(frame):
logger.error(f" {line.strip()}")
logger.error("Exiting due to signal")
sys.exit(1)
# Handle common termination signals
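        # (SIGKILL cannot be caught, so a hard kill still exits without this trace.)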
for sig in [signal.SIGINT, signal.SIGTERM]:
signal.signal(sig, signal_handler)
def quantise(self, url: str) -> dict[QuantisationType, QuantisationResult]:
"""Main quantisation workflow orchestrating model processing from URL to upload.
Coordinates the complete quantisation process from URL parsing through
model downloading, quantisation execution, and upload to HuggingFace.
        Checks architecture compatibility, pre-marking unsupported quantisation
        types as failed, and provides comprehensive error handling.
Returns:
Dictionary of quantisation results by type.
Raises:
KeyboardInterrupt: If the user interrupts the quantisation process.
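
        Example:
            Illustrative sketch (the URL is a placeholder) showing how the
            returned results can be inspected by type:

                results = QuantisationOrchestrator(no_upload=True).quantise(
                    "https://huggingface.co/org/model"
                )
                for quant_type, result in results.items():
                    logger.info(f"{quant_type}: {result.status}")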
"""
logger.info("Starting Bartowski quantisation process...")
logger.debug(f"DEBUG: Input URL: {url}")
logger.debug(f"DEBUG: Working directory: {self.work_dir}")
logger.debug(f"DEBUG: Use imatrix: {self.use_imatrix}")
logger.debug(f"DEBUG: No upload: {self.no_upload}")
logger.debug(f"DEBUG: Custom profiles: {self.custom_profiles}")
try:
# Setup and preparation
model_source, f16_model_path, imatrix_path, output_repo = self._setup_environment(url)
# Create initial repository
self._create_initial_repository(model_source, output_repo)
# Get quantisation types
quantisation_types = self.profile_manager.get_quantisation_types(self.custom_profiles)
# Filter by architecture if needed
supported_types, unsupported_types = self.profile_manager.filter_by_architecture(
quantisation_types, f16_model_path
)
# Pre-mark unsupported types
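            # so they still appear as failed in the completion summary rather
            # than being silently dropped from the run.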
results: dict[QuantisationType, QuantisationResult] = {}
for quant_type in unsupported_types:
results[quant_type] = QuantisationResult(
quantisation_type=quant_type,
success=False,
status="failed",
error_message="K-quant requires llama.cpp architecture support",
)
# Execute quantisations
execution_results = self.quantisation_executor.execute_quantisations(
model_source,
f16_model_path,
imatrix_path,
output_repo,
supported_types,
self.workspace_manager.models_dir,
)
results.update(execution_results)
# Cleanup
self.file_cleanup.cleanup_files(
f16_model_path, model_source, self.workspace_manager.models_dir
)
# Print summary
self.progress_reporter.print_completion_summary(model_source, results, output_repo)
except KeyboardInterrupt:
logger.error("❌ Process interrupted by user (Ctrl+C)")
raise
except Exception as e:
logger.error(f"❌ Critical error in quantisation workflow: {e}")
logger.error("Full traceback:")
for line in traceback.format_exc().splitlines():
logger.error(f" {line}")
raise
finally:
# Always flush pending README updates before exiting
self.readme_limiter.flush()
return results
def _setup_environment(self, url: str) -> tuple[ModelSource, Path, Path | None, str]:
"""Setup environment and prepare model for quantisation.
Returns:
Tuple of (model_source, f16_model_path, imatrix_path, output_repo).
"""
model_source = self.url_parser.parse(url)
self.progress_reporter.print_model_info(
model_source, self.uploader.get_username(), str(self.work_dir)
)
f16_model_path = self.model_manager.prepare_model(model_source)
output_repo = (
f"{self.uploader.get_username()}/"
f"{model_source.original_author}-{model_source.model_name}-GGUF"
)
imatrix_path = None
if self.use_imatrix:
logger.info("Checking for importance matrix (imatrix)...")
model_dir = self.workspace_manager.get_model_dir(model_source.model_name)
imatrix_path = self.imatrix_handler.find_imatrix(model_dir)
# If no imatrix found, offer to generate or provide one
if not imatrix_path:
# First offer to generate
imatrix_path = self.imatrix_generator.prompt_for_generation(
model_source, model_dir, f16_model_path
)
# If generation was skipped, offer to provide existing one
if not imatrix_path:
imatrix_path = self.imatrix_handler.prompt_for_user_imatrix(model_dir)
return model_source, f16_model_path, imatrix_path, output_repo
def _create_initial_repository(self, model_source: ModelSource, output_repo: str) -> None:
"""Create initial repository with planned quantisations."""
logger.info("Creating initial README with planned quantisations...")
quantisation_types = self.profile_manager.get_quantisation_types(self.custom_profiles)
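        # Every requested type starts as "planned" so the initial README lists
        # the full set before any quantisation work begins.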
planned_results = {
qt: QuantisationResult(quantisation_type=qt, success=False, status="planned")
for qt in quantisation_types
}
readme_path = self.readme_generator.generate(
model_source, planned_results, self.workspace_manager.models_dir, output_repo
)
if not self.no_upload:
logger.info("Creating repository with planned quantisations...")
self.uploader.upload_readme(output_repo, readme_path)
else:
logger.info("Skipping repository creation (--no-upload specified)")