"""Main quantisation orchestrator.

Provides the high-level orchestration of the complete quantisation
workflow, coordinating between various services and modules.
"""

from __future__ import annotations

import signal
import sys
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

from helpers.filesystem import FileCleanup, WorkspaceManager
from helpers.huggingface import HuggingFaceUploader
from helpers.llama_cpp import IMatrixGenerator, IMatrixHandler
from helpers.logger import logger
from helpers.models.quantisation import QuantisationResult, QuantisationType
from helpers.quantisation.engine import QuantisationEngine
from helpers.quantisation.executor import QuantisationExecutor
from helpers.quantisation.model_manager import ModelManager
from helpers.quantisation.profile_manager import ProfileManager
from helpers.quantisation.progress import ProgressReporter
from helpers.readme import ReadmeGenerator
from helpers.utils.rate_limiter import ReadmeRateLimiter
from helpers.utils.tensor_mapping import URLParser

# These names are referenced only inside annotations in this module, and
# `from __future__ import annotations` keeps annotations unevaluated at
# runtime, so the imports are restricted to type-checking passes.
if TYPE_CHECKING:
    from types import FrameType

    from helpers.models.quantisation import ModelSource


@dataclass(slots=True)
class QuantisationOrchestrator:
    """Orchestrates the complete quantisation workflow.

    Thin coordinator that delegates to specialised services for
    each aspect of the quantisation workflow.

    Note:
        Constructing an instance attempts to install process-wide SIGINT and
        SIGTERM handlers that log the interrupted stack and exit with
        status 1 (see ``_setup_signal_handlers``).
    """

    # Root directory handed to WorkspaceManager; defaults to
    # ``<cwd>/quantisation_work`` evaluated at construction time.
    work_dir: Path = field(default_factory=lambda: Path.cwd() / "quantisation_work")
    # When True, an importance matrix is looked up and, if missing, the
    # user is offered generation or supplying an existing one.
    use_imatrix: bool = True
    # When True, the initial README upload is skipped; the flag is also
    # forwarded to the QuantisationExecutor.
    no_upload: bool = False
    # Passed through unchanged to ProfileManager.get_quantisation_types.
    custom_profiles: list[str] | None = None

    # Service dependencies
    url_parser: URLParser = field(default_factory=URLParser)
    workspace_manager: WorkspaceManager = field(init=False)
    model_manager: ModelManager = field(init=False)
    profile_manager: ProfileManager = field(default_factory=ProfileManager)
    progress_reporter: ProgressReporter = field(default_factory=ProgressReporter)
    quantisation_executor: QuantisationExecutor = field(init=False)
    imatrix_handler: IMatrixHandler = field(default_factory=IMatrixHandler)
    imatrix_generator: IMatrixGenerator = field(default_factory=IMatrixGenerator)
    readme_generator: ReadmeGenerator = field(default_factory=ReadmeGenerator)
    uploader: HuggingFaceUploader = field(default_factory=HuggingFaceUploader)
    file_cleanup: FileCleanup = field(default_factory=FileCleanup)
    readme_limiter: ReadmeRateLimiter = field(init=False)

    def __post_init__(self) -> None:
        """Initialise computed properties after dataclass construction.

        Builds the ``init=False`` services that depend on other fields
        (workspace, model manager, README rate limiter, executor) and then
        installs the signal handlers.
        """
        self.workspace_manager = WorkspaceManager(self.work_dir)
        self.model_manager = ModelManager(self.workspace_manager.models_dir)
        self.readme_limiter = ReadmeRateLimiter(cooldown_seconds=30.0)

        # The executor shares this orchestrator's uploader, README generator
        # and cleanup service rather than creating its own.
        self.quantisation_executor = QuantisationExecutor(
            quantisation_engine=QuantisationEngine(),
            uploader=self.uploader,
            readme_generator=self.readme_generator,
            file_cleanup=self.file_cleanup,
            no_upload=self.no_upload,
        )

        self._setup_signal_handlers()

    def _setup_signal_handlers(self) -> None:
        """Set up signal handlers to catch unexpected exits.

        Registers one handler for SIGINT and SIGTERM that logs the signal,
        logs the stack of the interrupted frame (when one is supplied) and
        terminates via ``sys.exit(1)``.

        ``signal.signal`` raises ``ValueError`` when it is not called from
        the main thread of the main interpreter. In that case registration
        is skipped and logged instead of making construction fail.
        """

        def signal_handler(signum: int, frame: FrameType | None) -> None:
            logger.error(f"❌ Received signal {signum} ({signal.Signals(signum).name})")
            logger.error("Stack trace at signal:")
            if frame:
                for line in traceback.format_stack(frame):
                    logger.error(f" {line.strip()}")
            logger.error("Exiting due to signal")
            sys.exit(1)

        # Handle common termination signals
        try:
            for sig in (signal.SIGINT, signal.SIGTERM):
                signal.signal(sig, signal_handler)
        except ValueError as exc:
            # Handlers can only be installed from the main thread; keep the
            # orchestrator usable elsewhere, just without the exit logging.
            logger.debug(f"Signal handlers not installed: {exc}")

    def quantise(self, url: str) -> dict[QuantisationType, QuantisationResult]:
        """Main quantisation workflow orchestrating model processing from URL to upload.

        Steps, in order: prepare the environment for ``url``, create the
        initial README/repository, resolve the quantisation types, split them
        by architecture support, run the supported ones through the executor,
        clean up working files and print a completion summary. Unsupported
        types are reported as failed results without being executed.

        Pending README updates are flushed on every exit path. File cleanup
        and the summary only run when all earlier steps succeeded.

        Args:
            url: Model URL, interpreted by ``self.url_parser``.

        Returns:
            Dictionary of quantisation results by type, covering both the
            pre-marked unsupported types and the executed ones.

        Raises:
            KeyboardInterrupt: Logged and re-raised if it reaches this method.
                While the SIGINT handler installed at construction is active,
                Ctrl+C surfaces as ``SystemExit(1)`` from that handler
                instead.
            Exception: Any other error is logged with its traceback and
                re-raised unchanged.
        """
        logger.info("Starting Bartowski quantisation process...")
        logger.debug(f"DEBUG: Input URL: {url}")
        logger.debug(f"DEBUG: Working directory: {self.work_dir}")
        logger.debug(f"DEBUG: Use imatrix: {self.use_imatrix}")
        logger.debug(f"DEBUG: No upload: {self.no_upload}")
        logger.debug(f"DEBUG: Custom profiles: {self.custom_profiles}")

        try:
            # Setup and preparation
            model_source, f16_model_path, imatrix_path, output_repo = self._setup_environment(url)

            # Create initial repository
            self._create_initial_repository(model_source, output_repo)

            # Get quantisation types and split them by architecture support
            quantisation_types = self.profile_manager.get_quantisation_types(self.custom_profiles)
            supported_types, unsupported_types = self.profile_manager.filter_by_architecture(
                quantisation_types, f16_model_path
            )

            # Pre-mark unsupported types so they still appear in the results
            results: dict[QuantisationType, QuantisationResult] = {
                quant_type: QuantisationResult(
                    quantisation_type=quant_type,
                    success=False,
                    status="failed",
                    error_message="K-quant requires llama.cpp architecture support",
                )
                for quant_type in unsupported_types
            }

            # Execute quantisations
            execution_results = self.quantisation_executor.execute_quantisations(
                model_source,
                f16_model_path,
                imatrix_path,
                output_repo,
                supported_types,
                self.workspace_manager.models_dir,
            )
            results.update(execution_results)

            # Cleanup
            self.file_cleanup.cleanup_files(
                f16_model_path, model_source, self.workspace_manager.models_dir
            )

            # Print summary
            self.progress_reporter.print_completion_summary(model_source, results, output_repo)

        except KeyboardInterrupt:
            logger.error("❌ Process interrupted by user (Ctrl+C)")
            raise
        except Exception as e:
            logger.error(f"❌ Critical error in quantisation workflow: {e}")
            logger.error("Full traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")
            raise
        finally:
            # Always flush pending README updates before exiting
            self.readme_limiter.flush()

        return results

    def _setup_environment(self, url: str) -> tuple[ModelSource, Path, Path | None, str]:
        """Setup environment and prepare model for quantisation.

        Parses ``url``, prints the model information, prepares the F16 model
        through the model manager, derives the output repository name and,
        when ``use_imatrix`` is set, resolves an importance matrix.

        The imatrix is resolved by trying, in order: an existing file in the
        model directory, the generator's prompt, then the handler's prompt
        for a user-supplied file. Later steps run only while the result is
        still falsy.

        Args:
            url: Model URL, interpreted by ``self.url_parser``.

        Returns:
            Tuple of (model_source, f16_model_path, imatrix_path, output_repo).
            ``imatrix_path`` is ``None`` when ``use_imatrix`` is False.
        """
        model_source = self.url_parser.parse(url)

        # Look the username up once; it is needed for the info banner and
        # for the repository name.
        username = self.uploader.get_username()
        self.progress_reporter.print_model_info(model_source, username, str(self.work_dir))

        f16_model_path = self.model_manager.prepare_model(model_source)

        output_repo = f"{username}/{model_source.original_author}-{model_source.model_name}-GGUF"

        imatrix_path = None
        if self.use_imatrix:
            logger.info("Checking for importance matrix (imatrix)...")
            model_dir = self.workspace_manager.get_model_dir(model_source.model_name)
            imatrix_path = self.imatrix_handler.find_imatrix(model_dir)

            # If no imatrix found, first offer to generate one
            if not imatrix_path:
                imatrix_path = self.imatrix_generator.prompt_for_generation(
                    model_source, model_dir, f16_model_path
                )

            # If generation was skipped, offer to provide an existing one
            if not imatrix_path:
                imatrix_path = self.imatrix_handler.prompt_for_user_imatrix(model_dir)

        return model_source, f16_model_path, imatrix_path, output_repo

    def _create_initial_repository(self, model_source: ModelSource, output_repo: str) -> None:
        """Create initial repository with planned quantisations.

        Generates a README in which every requested quantisation type has
        status ``"planned"``, then uploads it unless ``no_upload`` is set.
        The README is generated in both cases.

        Args:
            model_source: Parsed model description passed to the README
                generator.
            output_repo: Target repository identifier for the README upload.
        """
        logger.info("Creating initial README with planned quantisations...")
        quantisation_types = self.profile_manager.get_quantisation_types(self.custom_profiles)
        planned_results = {
            qt: QuantisationResult(quantisation_type=qt, success=False, status="planned")
            for qt in quantisation_types
        }
        readme_path = self.readme_generator.generate(
            model_source, planned_results, self.workspace_manager.models_dir, output_repo
        )

        if not self.no_upload:
            logger.info("Creating repository with planned quantisations...")
            self.uploader.upload_readme(output_repo, readme_path)
        else:
            logger.info("Skipping repository creation (--no-upload specified)")