Switch to llama-cpp-python

parent ef7df1a8c3
commit d937f2d5fa

25 changed files with 2957 additions and 1181 deletions
@@ -4,17 +4,3 @@ Provides high-level service interfaces for interacting with external systems
 including HuggingFace, llama.cpp, and filesystem operations. Uses UK English
 spelling conventions throughout.
 """
-
-from __future__ import annotations
-
-from helpers.services.filesystem import FilesystemService
-from helpers.services.huggingface import HuggingFaceService, ReadmeGenerator
-from helpers.services.llama_cpp import EnvironmentManager, IMatrixGenerator
-
-__all__ = [
-    "EnvironmentManager",
-    "FilesystemService",
-    "HuggingFaceService",
-    "IMatrixGenerator",
-    "ReadmeGenerator",
-]
@@ -7,7 +7,8 @@ Uses UK English spelling conventions throughout.
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+import gc
+from typing import TYPE_CHECKING, Any, Protocol
 
 import gguf
 import torch
@@ -17,6 +18,25 @@ from helpers.logger import logger
 from helpers.services.filesystem import FilesystemService
 from helpers.utils.config_parser import ConfigParser
 
 
+class VisionConfig(Protocol):
+    """Protocol for vision model configuration."""
+
+    hidden_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    intermediate_size: int
+    patch_size: int
+    spatial_merge_size: int
+
+
+class TensorMapper(Protocol):
+    """Protocol for tensor name mapping."""
+
+    def map_tensor_name(self, name: str) -> str | None:
+        """Map a tensor name to its GGUF equivalent."""
+        ...
+
+
 if TYPE_CHECKING:
     from pathlib import Path
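The two Protocol classes above rely on structural typing: any object with matching attributes or methods satisfies the protocol, with no inheritance required. A minimal sketch (illustrative only, not part of this commit; the mapping rule is invented for the example):

    class SimpleTensorMapper:
        def map_tensor_name(self, name: str) -> str | None:
            if name == "model.embed_tokens.weight":
                return "token_embd.weight"
            return None  # unmapped tensors are skipped by the caller

    mapper: TensorMapper = SimpleTensorMapper()  # accepted by static type checkers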
@@ -71,7 +91,7 @@ class GGUFWriter:
 
         logger.info(f"Added metadata: {params.block_count} layers, {params.context_length} context")
 
-    def add_vision_metadata(self, vision_config: Any) -> None:
+    def add_vision_metadata(self, vision_config: VisionConfig | None) -> None:
         """Add vision model parameters to GGUF metadata.
 
         Configures vision-specific parameters for multimodal models including
@@ -141,7 +161,7 @@ class GGUFConverter:
         output_path: Path,
         model_config: ModelConfig,
         architecture: str,
-        tensor_mapper: Any,
+        tensor_mapper: TensorMapper,
     ) -> bool:
         """Convert SafeTensors model to GGUF format.
 
@@ -172,7 +192,7 @@ class GGUFConverter:
         for tensor_file in tensor_files:
             logger.info(f"Loading {tensor_file.name}...")
             with safe_open(tensor_file, framework="pt") as f:
-                for tensor_name in f:
+                for tensor_name in f.keys(): # noqa: SIM118
                     tensor_data = f.get_tensor(tensor_name)
 
                     # Convert BFloat16 to Float32
@@ -191,6 +211,12 @@ class GGUFConverter:
                     if tensor_count % 100 == 0:
                         logger.info(f" Processed {tensor_count} tensors...")
 
+                    # Free memory after processing each tensor
+                    del tensor_data
+
+            # Force garbage collection after processing each file
+            gc.collect()
+
         logger.info(f"Total tensors processed: {tensor_count}")
 
         # Add tokeniser
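The change above frees each tensor before the next one is loaded and forces a garbage-collection pass per file, keeping peak memory near the size of a single tensor. A standalone sketch of the same pattern (assuming safetensors and torch are installed; the function name is illustrative):

    import gc

    import torch
    from safetensors import safe_open

    def iter_tensors(tensor_file):
        # Yield tensors one at a time so only one is resident in memory.
        with safe_open(tensor_file, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                data = f.get_tensor(name)
                if data.dtype == torch.bfloat16:
                    data = data.float()  # BFloat16 -> Float32, as in the converter
                yield name, data
                del data  # drop our reference before fetching the next tensor
        gc.collect()  # reclaim memory once the file is done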
@@ -8,17 +8,22 @@ spelling conventions throughout.
 from __future__ import annotations
 
 import re
+import shutil
 import subprocess
+import tempfile
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 from helpers.config.quantisation_configs import QUANTISATION_CONFIGS
 from helpers.logger import logger
 from helpers.models.quantisation import QuantisationType
 
 if TYPE_CHECKING:
     from helpers.models.quantisation import ModelSource, QuantisationResult
 
+# Constants for file size formatting
+GIBIBYTE = 1024**3
+
 
 class HuggingFaceService:
     """Manages HuggingFace repository operations.
@@ -76,7 +81,7 @@ class HuggingFaceService:
         if include_pattern:
             cmd.extend(["--include", include_pattern])
 
-        subprocess.run(cmd, check=True)
+        subprocess.run(cmd, check=True, capture_output=True, text=True)
         logger.info("Download complete")
 
     @staticmethod
@@ -89,8 +94,8 @@ class HuggingFaceService:
         """Upload a file to HuggingFace repository.
 
         Uploads a single file to the specified repository path. Can create
-        the repository if it doesn't exist. Handles repository creation conflicts
-        gracefully by retrying without the create flag when needed.
+        the repository if it doesn't exist. Uses git directly when possible
+        to avoid automatic PR creation.
 
         Raises:
             CalledProcessError: If upload fails.
@@ -98,12 +103,25 @@ class HuggingFaceService:
         repo_path = repo_path or local_path.name
         logger.info(f"Uploading {local_path.name} to {repo_id}/{repo_path}")
 
+        # Try git-based upload first to avoid PR creation
+        if HuggingFaceService._try_git_upload(
+            repo_id, local_path, repo_path, create_repo=create_repo
+        ):
+            logger.info(f"Uploaded {repo_path} via git")
+            return
+
+        # Fallback to huggingface-cli
+        logger.info("Git upload failed, trying huggingface-cli...")
         cmd = [
             "huggingface-cli",
             "upload",
             repo_id,
             str(local_path),
             repo_path,
+            "--revision",
+            "main", # Explicitly push to main branch
+            "--commit-message",
+            f"Add {repo_path}",
         ]
 
         if create_repo:
@@ -116,11 +134,99 @@ class HuggingFaceService:
             if create_repo:
                 # Repository might already exist, retry without --create
                 cmd = cmd[:-1]  # Remove --create flag
-                subprocess.run(cmd, check=True)
+                subprocess.run(cmd, check=True, capture_output=True, text=True)
                 logger.info(f"Updated {repo_path}")
             else:
                 raise
 
+    @staticmethod
+    def _try_git_upload(
+        repo_id: str,
+        local_path: Path,
+        repo_path: str,
+        *,
+        create_repo: bool = False,
+    ) -> bool:
+        """Try to upload file using git directly to avoid PR creation.
+
+        Returns:
+            bool: True if upload successful, False if should fallback to CLI.
+        """
+        try:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                temp_path = Path(temp_dir)
+                repo_url = f"https://huggingface.co/{repo_id}"
+
+                # Clone repository
+                logger.info(f"Cloning {repo_url}...")
+                result = subprocess.run(
+                    ["git", "clone", repo_url, str(temp_path / "repo")],
+                    check=False,
+                    capture_output=True,
+                    text=True,
+                )
+
+                if result.returncode != 0:
+                    if create_repo:
+                        # Repository doesn't exist, let huggingface-cli handle creation
+                        return False
+                    logger.warning(f"Clone failed: {result.stderr}")
+                    return False
+
+                repo_dir = temp_path / "repo"
+                target_file = repo_dir / repo_path
+
+                # Ensure target directory exists
+                target_file.parent.mkdir(parents=True, exist_ok=True)
+
+                # Copy file
+                shutil.copy2(local_path, target_file)
+
+                # Check if there are any changes
+                status_result = subprocess.run(
+                    ["git", "status", "--porcelain"],
+                    cwd=repo_dir,
+                    capture_output=True,
+                    text=True,
+                    check=True,
+                )
+
+                if not status_result.stdout.strip():
+                    logger.info(f"No changes detected for {repo_path}, file already up-to-date")
+                    return True  # File is already up-to-date, no need to push
+
+                # Git add, commit, push
+                subprocess.run(
+                    ["git", "add", repo_path],
+                    cwd=repo_dir,
+                    check=True,
+                    capture_output=True,
+                    text=True,
+                )
+                subprocess.run(
+                    ["git", "commit", "-m", f"Update {repo_path}"],
+                    cwd=repo_dir,
+                    check=True,
+                    capture_output=True,
+                    text=True,
+                )
+                subprocess.run(
+                    ["git", "push"],
+                    cwd=repo_dir,
+                    check=True,
+                    capture_output=True,
+                    text=True,
+                )
+
+                return True
+
+        except subprocess.CalledProcessError as e:
+            logger.warning(f"Git upload failed: {e}")
+            return False
+        except Exception as e:
+            logger.warning(f"Git upload error: {e}")
+            return False
+
+
 class ReadmeGenerator:
     """Generates README files for quantised models.
@@ -173,14 +279,45 @@ class ReadmeGenerator:
         """
         content = {"readme": "", "licence": "apache-2.0", "tags": "", "frontmatter": ""}
 
-        # Try local file first
+        # Check for preserved original README first
+        original_readme_path = model_dir / "README.original.md"
         readme_path = model_dir / "README.md"
-        if readme_path.exists():
-            content["readme"] = readme_path.read_text(encoding="utf-8")
-            logger.info(f"Found original README ({len(content['readme'])} characters)")
+
+        if original_readme_path.exists():
+            # Use the preserved original
+            content["readme"] = original_readme_path.read_text(encoding="utf-8")
+            logger.info(f"Found preserved original README ({len(content['readme'])} characters)")
+        elif readme_path.exists():
+            # First time - preserve the original and use it
+            readme_content = readme_path.read_text(encoding="utf-8")
+
+            # Check if this is already our generated README
+            if (
+                f"{model_source.original_author}-{model_source.model_name}-GGUF"
+                not in readme_content
+            ):
+                # This is the original - preserve it
+                original_readme_path.write_text(readme_content, encoding="utf-8")
+                content["readme"] = readme_content
+                readme_len = len(content["readme"])
+                logger.info(
+                    f"Preserved original README as README.original.md ({readme_len} characters)"
+                )
+            else:
+                # This is our generated README, need to download the original
+                logger.info("Found generated README, downloading original from source")
+                content = self._download_readme(model_source)
+                # Save the downloaded original for future use
+                if content["readme"]:
+                    original_readme_path.write_text(content["readme"], encoding="utf-8")
+                    logger.info("Preserved downloaded original README as README.original.md")
         else:
-            # Download separately
+            # No local README - download from source
             content = self._download_readme(model_source)
+            # Save the downloaded original for future use
+            if content["readme"]:
+                original_readme_path.write_text(content["readme"], encoding="utf-8")
+                logger.info("Preserved downloaded original README as README.original.md")
 
         # Parse frontmatter if present
         if content["readme"].startswith("---\n"):
@@ -303,10 +440,16 @@ class ReadmeGenerator:
         our_tags = [
             "quantised",
             "gguf",
+            "q3_k_m",
+            "q3_k_l",
+            "q3_k_xl",
             "q4_k_m",
             "q4_k_l",
             "q4_k_xl",
             "q4_k_xxl",
+            "q5_k_m",
+            "q5_k_l",
+            "q6_k",
+            "q6_k_l",
+            "q8_0",
             "bartowski-method",
         ]
         original_tags = original_content["tags"].split(",") if original_content["tags"] else []
|
@ -329,62 +472,78 @@ tags:
|
|||
hf_url = f"https://huggingface.co/{model_source.source_model}"
|
||||
content = f"""# {model_source.original_author}-{model_source.model_name}-GGUF
|
||||
|
||||
GGUF quantisations of [{model_source.source_model}]({hf_url}) using Bartowski's method.
|
||||
GGUF quantisations of [{model_source.source_model}]({hf_url}) using
|
||||
[Bartowski](https://huggingface.co/bartowski)'s method. Created with [llm-gguf-tools](https://git.tomfos.tr/tom/llm-gguf-tools)
|
||||
which replicates Bartowski's quantisation profiles.
|
||||
|
||||
| Quantisation | Embeddings/Output | Attention | Feed-Forward | Status |
|
||||
|--------------|-------------------|-----------|--------------|--------|
|
||||
| Variant | Configuration | File Size | Status |
|
||||
|---|---|---|---|
|
||||
"""
|
||||
|
||||
# Add results table
|
||||
for quant_type in [
|
||||
# Add results table - group by layer config patterns
|
||||
supported_types = [
|
||||
QuantisationType.Q3_K_M,
|
||||
QuantisationType.Q3_K_L,
|
||||
QuantisationType.Q3_K_XL,
|
||||
QuantisationType.Q4_K_M,
|
||||
QuantisationType.Q4_K_L,
|
||||
QuantisationType.Q4_K_XL,
|
||||
QuantisationType.Q4_K_XXL,
|
||||
]:
|
||||
QuantisationType.Q5_K_M,
|
||||
QuantisationType.Q5_K_L,
|
||||
QuantisationType.Q6_K,
|
||||
QuantisationType.Q6_K_L,
|
||||
QuantisationType.Q8_0,
|
||||
]
|
||||
|
||||
for quant_type in supported_types:
|
||||
result = results.get(quant_type)
|
||||
if not result:
|
||||
result = type("Result", (), {"status": "planned", "success": False})()
|
||||
|
||||
layers = self._get_layers_config(quant_type)
|
||||
config = QUANTISATION_CONFIGS.get(quant_type)
|
||||
file_size = self._format_file_size(result)
|
||||
status = self._format_status(result, model_source, quant_type, output_repo)
|
||||
|
||||
content += (
|
||||
f"| {quant_type.value} | {layers['embeddings']} | "
|
||||
f"{layers['attention']} | {layers['ffn']} | {status} |\n"
|
||||
)
|
||||
# Get configuration description from the config itself
|
||||
config_desc = config.get_compact_config(QUANTISATION_CONFIGS) if config else f"{quant_type} all layers"
|
||||
|
||||
content += "\n---\n\n"
|
||||
content += f"| **{quant_type.value}** | {config_desc} | {file_size} | {status} |\n"
|
||||
|
||||
content += """
|
||||
|
||||
**Key:** `E` = Embeddings, `O` = Output, `A` = Attention, `F` = FFN
|
||||
|
||||
See [Bartowski Analysis](https://git.tomfos.tr/tom/llm-gguf-tools/src/branch/main/docs/bartowski_analysis.md)
|
||||
for detailed quantisation strategies and [Documentation](https://git.tomfos.tr/tom/llm-gguf-tools/src/branch/main/docs/)
|
||||
for more on the tools and methods I use.
|
||||
|
||||
"""
|
||||
|
||||
# Add original content
|
||||
if original_content["readme"]:
|
||||
content += "# Original Model Information\n\n" + original_content["readme"]
|
||||
content += "## Original Model Card\n\n---\n\n" + original_content["readme"]
|
||||
else:
|
||||
content += f"## Original Model\n\nQuantisation of [{model_source.source_model}](https://huggingface.co/{model_source.source_model}).\n"
|
||||
content += f"## Original Model\n\nQuantisation of [{model_source.source_model}](https://huggingface.co/{model_source.source_model})."
|
||||
|
||||
return frontmatter + content
|
||||
|
||||
def _get_layers_config(self, quant_type: QuantisationType) -> dict[str, str]:
|
||||
"""Get layer configuration for quantisation type.
|
||||
|
||||
Returns layer precision specifications for the quantisation table.
|
||||
def _format_file_size(self, result: QuantisationResult) -> str:
|
||||
"""Format file size for README table.
|
||||
|
||||
Returns:
|
||||
Dictionary with embeddings, attention, and ffn precision labels.
|
||||
Formatted file size string or dash if not available.
|
||||
"""
|
||||
configs = {
|
||||
QuantisationType.Q4_K_M: {
|
||||
"embeddings": "Q4_K_M",
|
||||
"attention": "Q4_K_M",
|
||||
"ffn": "Q4_K_M",
|
||||
},
|
||||
QuantisationType.Q4_K_L: {"embeddings": "Q6_K", "attention": "Q6_K", "ffn": "Q4_K_M"},
|
||||
QuantisationType.Q4_K_XL: {"embeddings": "Q8_0", "attention": "Q6_K", "ffn": "Q4_K_M"},
|
||||
QuantisationType.Q4_K_XXL: {"embeddings": "Q8_0", "attention": "Q8_0", "ffn": "Q4_K_M"},
|
||||
}
|
||||
return configs.get(
|
||||
quant_type, {"embeddings": "Unknown", "attention": "Unknown", "ffn": "Unknown"}
|
||||
)
|
||||
if hasattr(result, "file_size") and result.file_size:
|
||||
return result.file_size
|
||||
if hasattr(result, "success") and result.success and hasattr(result, "file_path"):
|
||||
# Try to get file size from path if available
|
||||
try:
|
||||
if result.file_path and Path(result.file_path).exists():
|
||||
size_bytes = Path(result.file_path).stat().st_size
|
||||
size_gb = size_bytes / GIBIBYTE
|
||||
return f"{size_gb:.1f}GB"
|
||||
except Exception:
|
||||
pass
|
||||
return "-"
|
||||
|
||||
def _format_status(
|
||||
self,
|
||||
|
@@ -402,7 +561,7 @@ GGUF quantisations of [{model_source.source_model}]({hf_url}) using Bartowski's
             Formatted status string for table cell.
         """
         status_map = {
-            "planned": "⏳ Planned",
+            "planned": "⏳ Queued",
             "processing": "🔄 Processing...",
             "uploading": "⬆️ Uploading...",
             "failed": "❌ Failed",
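The new File Size column is produced by _format_file_size, which divides the on-disk size by the GIBIBYTE constant introduced earlier. The same arithmetic in isolation (the byte count is hypothetical):

    GIBIBYTE = 1024**3
    size_bytes = 4_400_000_000  # hypothetical quantised model size
    print(f"{size_bytes / GIBIBYTE:.1f}GB")  # -> "4.1GB", the table's format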
@@ -1,198 +1,42 @@
-"""llama.cpp environment and operations service.
+"""Importance matrix (imatrix) management service.
 
-Manages llama.cpp binary discovery, environment setup, and imatrix generation.
-Provides consistent interface for interacting with llama.cpp tools across
-different installation methods.
+Manages detection and use of existing importance matrix files for
+quantisation guidance. Provides user prompts for supplying pre-computed
+imatrix files from external sources.
 """
 
 from __future__ import annotations
 
-import subprocess
-from pathlib import Path
 from typing import TYPE_CHECKING
 
 from helpers.logger import logger
-from helpers.models.quantisation import LlamaCppEnvironment
 from helpers.services.filesystem import FilesystemService
 
-
-class EnvironmentManager:
-    """Manages llama.cpp environment setup and binary discovery.
-
-    Handles detection of local binaries, repository setup, and conversion
-    script location. Provides fallback strategies for different installation
-    scenarios including local builds and repository-based setups.
-    """
-
-    def __init__(self, work_dir: Path) -> None:
-        """Initialise EnvironmentManager."""
-        self.work_dir = work_dir
-        self.llama_cpp_dir = work_dir / "llama.cpp"
-        self.fs = FilesystemService()
-
-    def setup(self) -> LlamaCppEnvironment:
-        """Set up llama.cpp environment with automatic detection.
-
-        Checks for local llama.cpp binaries first, then falls back to
-        repository-based setup if needed. Handles conversion script location,
-        dependency installation, and path resolution.
-
-        Returns:
-            Configured LlamaCppEnvironment instance.
-        """
-        # Check for local binaries first
-        local_env = self._check_local_binaries()
-        if local_env:
-            return local_env
-
-        # Setup repository if needed
-        return self.setup_repository()
-
-    def _check_local_binaries(self) -> LlamaCppEnvironment | None:
-        """Check for existing llama.cpp binaries in current directory.
-
-        Searches for quantise and CLI binaries in the current directory
-        and standard installation paths. Also locates conversion scripts.
-
-        Returns:
-            LlamaCppEnvironment if binaries found, None otherwise.
-        """
-        quantise_bin = Path("./llama-quantize")
-        cli_bin = Path("./llama-cli")
-
-        if not (quantise_bin.exists() and cli_bin.exists()):
-            return None
-
-        logger.info("Found llama.cpp binaries in current directory")
-
-        # Check for conversion script
-        convert_script = self._find_convert_script()
-        if convert_script:
-            logger.info(f"Found conversion script: {convert_script}")
-            return LlamaCppEnvironment(
-                quantise_binary=quantise_bin.resolve(),
-                cli_binary=cli_bin.resolve(),
-                convert_script=convert_script,
-                use_repo=False,
-            )
-
-        logger.warning("No conversion script found in current directory")
-        logger.info("Will use llama.cpp repository method for conversion")
-        return LlamaCppEnvironment(
-            quantise_binary=quantise_bin.resolve(),
-            cli_binary=cli_bin.resolve(),
-            convert_script=f"python3 {self.llama_cpp_dir}/convert_hf_to_gguf.py",
-            use_repo=True,
-        )
-
-    def _find_convert_script(self) -> str | None:
-        """Find conversion script in current directory.
-
-        Searches for various naming conventions of the HF to GGUF
-        conversion script.
-
-        Returns:
-            Command to run conversion script, or None if not found.
-        """
-        scripts = [
-            "./llama-convert-hf-to-gguf",
-            "python3 ./convert_hf_to_gguf.py",
-            "python3 ./convert-hf-to-gguf.py",
-        ]
-
-        for script in scripts:
-            if script.startswith("python3"):
-                script_path = script.split(" ", 1)[1]
-                if Path(script_path).exists():
-                    return script
-            elif Path(script).exists():
-                return script
-        return None
-
-    def setup_repository(self) -> LlamaCppEnvironment:
-        """Setup llama.cpp repository for conversion scripts.
-
-        Clones the llama.cpp repository if not present and installs
-        Python dependencies for model conversion.
-
-        Returns:
-            LlamaCppEnvironment configured with repository paths.
-        """
-        if not self.llama_cpp_dir.exists():
-            logger.info("Cloning llama.cpp for conversion script...")
-            subprocess.run(
-                [
-                    "git",
-                    "clone",
-                    "https://github.com/ggerganov/llama.cpp.git",
-                    str(self.llama_cpp_dir),
-                ],
-                check=True,
-            )
-
-            # Install Python requirements
-            logger.info("Installing Python requirements...")
-            subprocess.run(
-                [
-                    "pip3",
-                    "install",
-                    "-r",
-                    "requirements.txt",
-                    "--break-system-packages",
-                    "--root-user-action=ignore",
-                ],
-                cwd=self.llama_cpp_dir,
-                check=True,
-            )
-
-            # Install additional conversion dependencies
-            logger.info("Installing additional conversion dependencies...")
-            subprocess.run(
-                [
-                    "pip3",
-                    "install",
-                    "transformers",
-                    "sentencepiece",
-                    "protobuf",
-                    "--break-system-packages",
-                    "--root-user-action=ignore",
-                ],
-                check=True,
-            )
-        else:
-            logger.info("llama.cpp repository already exists")
-
-        # Use local binaries but repo conversion script
-        return LlamaCppEnvironment(
-            quantise_binary=Path("./llama-quantize").resolve(),
-            cli_binary=Path("./llama-cli").resolve(),
-            convert_script=f"python3 {self.llama_cpp_dir}/convert_hf_to_gguf.py",
-            use_repo=False,
-        )
+if TYPE_CHECKING:
+    from pathlib import Path
 
 
-class IMatrixGenerator:
-    """Handles importance matrix generation for quantisation guidance.
+class IMatrixManager:
+    """Handles importance matrix file management for quantisation.
 
-    Generates or locates importance matrices that guide quantisation
-    decisions, helping preserve model quality by identifying critical
-    tensors requiring higher precision.
+    Locates existing importance matrix files or prompts users to provide
+    pre-computed matrices from external sources. These matrices guide
+    quantisation decisions to preserve model quality.
     """
 
     def __init__(self) -> None:
-        """Initialise IMatrixGenerator."""
+        """Initialise IMatrixManager."""
         self.fs = FilesystemService()
 
-    def generate_imatrix(
-        self, f16_model_path: Path, llama_env: LlamaCppEnvironment, model_dir: Path
-    ) -> Path | None:
-        """Generate importance matrix for quantisation guidance.
+    def find_imatrix(self, model_dir: Path) -> Path | None:
+        """Find or prompt for importance matrix file.
 
-        Searches for existing imatrix files first, provides interactive
-        prompts for user-supplied matrices, then generates new matrices
-        using calibration data if necessary.
+        Searches for existing imatrix files first, then provides interactive
+        prompts for user-supplied matrices. See docs/imatrix_data.md for
+        instructions on generating imatrix files.
 
         Returns:
-            Path to imatrix file, or None if generation fails.
+            Path to imatrix file, or None if not available.
         """
         imatrix_path = model_dir / "imatrix.dat"
@@ -202,16 +46,7 @@ class IMatrixGenerator:
             return imatrix_path
 
         # Try user-provided imatrix
-        user_imatrix = self._prompt_for_user_imatrix(model_dir, imatrix_path)
-        if user_imatrix:
-            return user_imatrix
-
-        # Generate new imatrix
-        calibration_file = self._get_calibration_file()
-        if not calibration_file:
-            return None
-
-        return self._generate_new_imatrix(f16_model_path, llama_env, imatrix_path, calibration_file)
+        return self._prompt_for_user_imatrix(model_dir, imatrix_path)
 
     def _prompt_for_user_imatrix(self, model_dir: Path, imatrix_path: Path) -> Path | None:
         """Prompt user for existing imatrix file.
@@ -221,197 +56,28 @@ class IMatrixGenerator:
         """
-        logger.info(f"Model directory: {model_dir}")
-        logger.info(f"Looking for imatrix file at: {imatrix_path}")
-        logger.info(
-            "Tip: You can download pre-computed imatrix files from Bartowski's repositories!"
-        )
-        logger.info(
-            " Example: https://huggingface.co/bartowski/MODEL-NAME-GGUF/resolve/main/MODEL-NAME.imatrix"
-        )
+        logger.info("\n" + "=" * 70)
+        logger.info("📊 No existing imatrix file found")
+        logger.info("\nYou have two options:")
+        logger.info(" 1. Provide a pre-computed imatrix file")
+        logger.info(" (💡 see docs/imatrix_data.md to generate your own)")
+        logger.info(" 2. Skip imatrix usage (lower quality quantisation)")
+        logger.info("=" * 70)
 
-        response = (
-            input("\n❓ Do you have an imatrix file to place in the model directory? (y/N): ")
-            .strip()
-            .lower()
-        )
+        response = input("\n❓ Do you have an imatrix file to provide? (y/N): ").strip().lower()
 
         if response != "y":
-            logger.info("Continuing without imatrix (quantisation quality may be lower)")
+            logger.info("ℹ️ See docs/imatrix_data.md for instructions on generating imatrix files") # noqa: RUF001
            return None
 
-        logger.info(f"Please place your imatrix.dat file in: {model_dir}")
-        input("⏳ Press Enter when you've placed the imatrix.dat file (or Ctrl+C to cancel)...")
+        logger.info(f"\nPlease place your imatrix.dat file in: {model_dir}")
+        input("⏳ Press Enter when you've placed the file (or Ctrl+C to cancel)...")
 
         if imatrix_path.exists():
             file_size = self.fs.get_file_size(imatrix_path)
-            logger.info(f"Found imatrix file! ({file_size})")
+            logger.info(f"✅ Found imatrix file! ({file_size})")
             return imatrix_path
 
-        logger.warning("No imatrix.dat file found - continuing with automatic generation")
-        return None
-
-    def _get_calibration_file(self) -> Path | None:
-        """Get calibration data file for imatrix generation.
-
-        Returns:
-            Path to calibration file, or None if not found.
-        """
-        calibration_file = Path(__file__).parent.parent.parent / "resources" / "imatrix_data.txt"
-        if not calibration_file.exists():
-            logger.warning("resources/imatrix_data.txt not found - skipping imatrix generation")
-            logger.info(
-                "Download from: https://gist.githubusercontent.com/bartowski1182/"
-                "eb213dccb3571f863da82e99418f81e8/raw/calibration_datav3.txt"
-            )
-            return None
-        return calibration_file
-
-    def _generate_new_imatrix(
-        self,
-        f16_model_path: Path,
-        llama_env: LlamaCppEnvironment,
-        imatrix_path: Path,
-        calibration_file: Path,
-    ) -> Path | None:
-        """Generate new importance matrix using calibration data.
-
-        Returns:
-            Path to generated imatrix, or None if generation fails.
-        """
-        logger.info("Generating importance matrix (this may take 1-4 hours for large models)...")
-        logger.info(f"Model: {f16_model_path.name}")
-        logger.info(f"Calibration: {calibration_file}")
-        logger.info(f"Output: {imatrix_path}")
-
-        # Find imatrix binary
-        imatrix_binary = self._find_imatrix_binary(llama_env)
-        if not imatrix_binary:
-            logger.warning("llama-imatrix binary not found - skipping imatrix generation")
-            logger.info("Make sure llama-imatrix is in the same directory as llama-quantize")
-            return None
-
-        # Build and execute command
-        cmd = self._build_imatrix_command(
-            imatrix_binary, f16_model_path, calibration_file, imatrix_path
-        )
-        return self._execute_imatrix_generation(cmd, imatrix_path)
-
-    def _build_imatrix_command(
-        self, binary: Path, model_path: Path, calibration_file: Path, output_path: Path
-    ) -> list[str]:
-        """Build imatrix generation command.
-
-        Returns:
-            Command arguments as list.
-        """
-        return [
-            str(binary),
-            "-m",
-            str(model_path),
-            "-f",
-            str(calibration_file),
-            "-o",
-            str(output_path),
-            "--process-output",
-            "--output-frequency",
-            "10",
-            "--save-frequency",
-            "50",
-            "-t",
-            "8",
-            "-c",
-            "2048",
-            "-b",
-            "512",
-        ]
-
-    def _execute_imatrix_generation(self, cmd: list[str], imatrix_path: Path) -> Path | None:
-        """Execute imatrix generation command with real-time output.
-
-        Returns:
-            Path to generated imatrix file, or None if generation fails.
-        """
-        logger.info(f"Running: {' '.join(cmd)}")
-        logger.info("Starting imatrix generation... (progress will be shown)")
-
-        try:
-            process = subprocess.Popen(
-                cmd,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.STDOUT,
-                universal_newlines=True,
-                bufsize=1,
-            )
-
-            self._stream_imatrix_output(process)
-
-            return_code = process.poll()
-            if return_code == 0:
-                return self._validate_imatrix_output(imatrix_path)
-
-        except KeyboardInterrupt:
-            logger.info("imatrix generation cancelled by user")
-            process.terminate()
-            return None
-        except Exception as e:
-            logger.error(f"imatrix generation failed with exception: {e}")
-            return None
-        else:
-            logger.error(f"imatrix generation failed with return code {return_code}")
-            return None
-
-    def _stream_imatrix_output(self, process: subprocess.Popen) -> None:
-        """Stream imatrix generation output in real-time."""
-        while True:
-            if process.stdout is not None:
-                output = process.stdout.readline()
-            else:
-                break
-            if not output and process.poll() is not None:
-                break
-            if output:
-                line = output.strip()
-                if self._should_log_imatrix_line(line):
-                    logger.info(line)
-
-    def _should_log_imatrix_line(self, line: str) -> bool:
-        """Determine if imatrix output line should be logged.
-
-        Returns:
-            True if line should be logged, False otherwise.
-        """
-        keywords = ["Computing imatrix", "perplexity:", "save_imatrix", "entries =", "ETA"]
-        return any(keyword in line for keyword in keywords) or line.startswith("[")
-
-    def _validate_imatrix_output(self, imatrix_path: Path) -> Path | None:
-        """Validate generated imatrix file.
-
-        Returns:
-            Path to imatrix if valid, None otherwise.
-        """
-        if imatrix_path.exists():
-            file_size = self.fs.get_file_size(imatrix_path)
-            logger.info(f"imatrix generation successful! ({file_size})")
-            return imatrix_path
-        logger.error("imatrix generation completed but file not found")
-        return None
-
-    def _find_imatrix_binary(self, llama_env: LlamaCppEnvironment) -> Path | None:
-        """Find llama-imatrix binary in common locations.
-
-        Searches for the imatrix binary in the current directory and
-        standard installation paths.
-
-        Returns:
-            Path to imatrix binary, or None if not found.
-        """
-        candidates = [
-            Path("./llama-imatrix"),
-            llama_env.quantise_binary.parent / "llama-imatrix",
-            Path("/usr/local/bin/llama-imatrix"),
-            Path("/usr/bin/llama-imatrix"),
-        ]
-
-        for candidate in candidates:
-            if candidate.exists() and candidate.is_file():
-                return candidate
-
-        return None
+        logger.warning("No imatrix.dat file found - continuing without imatrix")
+        return None
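With generation removed, callers simply ask the manager for an existing matrix and degrade gracefully. A minimal usage sketch (the directory path is hypothetical):

    from pathlib import Path

    manager = IMatrixManager()
    imatrix = manager.find_imatrix(Path("./work/models/example-model"))
    if imatrix is None:
        print("Proceeding without an importance matrix (quality may be lower)")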
helpers/services/llama_python.py (new file, 756 lines)
@@ -0,0 +1,756 @@
+"""Python API wrapper for llama-cpp-python quantisation operations.
+
+Provides high-level Python interfaces for model quantisation using llama-cpp-python
+bindings. Implements partial tensor-specific quantisation support through embedding
+and output tensor type configuration.
+"""
+
+from __future__ import annotations
+
+import ctypes
+import gc
+import logging
+import os
+import signal
+import sys
+import traceback
+from typing import TYPE_CHECKING, Any, ClassVar, Never
+
+import psutil
+
+from helpers.logger import logger
+from helpers.services.gguf import GGUFConverter
+from helpers.utils.config_parser import ConfigParser
+from helpers.utils.tensor_mapping import TensorMapper
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from helpers.models.quantisation import QuantisationConfig
+
+# Import llama_cpp when needed
+try:
+    import llama_cpp
+    from llama_cpp import llama_model_quantize_params
+
+    LLAMA_CPP_AVAILABLE = True
+except ImportError:
+    LLAMA_CPP_AVAILABLE = False
+    logger.warning("llama-cpp-python not available - falling back to binary mode")
+
+
+class LlamaCppPythonAPI:
+    """Python API wrapper for llama.cpp quantisation operations.
+
+    Provides direct Python access to quantisation functionality using llama-cpp-python
+    bindings. Implements partial tensor-specific quantisation through token embedding
+    and output tensor type configuration, which provides differentiation between
+    Q4_K variants even without full per-layer tensor control.
+    """
+
+    # Mapping of custom variant prefixes to their base types
+    VARIANT_BASE_MAPPING: ClassVar[dict[str, str]] = {
+        "Q3_K_": "Q3_K_M",
+        "Q4_K_": "Q4_K_M",
+        "Q5_K_": "Q5_K_M",
+        "Q6_K_": "Q6_K",
+    }
+
+    @staticmethod
+    def is_available() -> bool:
+        """Check if llama-cpp-python is available for use.
+
+        Returns:
+            True if llama-cpp-python bindings are installed and functional.
+        """
+        return LLAMA_CPP_AVAILABLE
+
+    @staticmethod
+    def get_quantisation_type(config_name: str) -> int:
+        """Map configuration name to llama_cpp quantisation type constant.
+
+        Supports a wide range of quantisation types from Q2 to Q8, including
+        K-quants and legacy formats. Handles both simple formats (Q4_K_M, Q6_K)
+        and custom suffixed variants (Q4_K_M_L, Q5_K_M_XL) by mapping them to
+        their base types for llama-cpp-python compatibility.
+
+        Returns:
+            llama_cpp quantisation type constant for base quantisation.
+
+        Raises:
+            RuntimeError: If llama-cpp-python is not available.
+            ValueError: If the quantisation type is not supported.
+        """
+        if not LLAMA_CPP_AVAILABLE:
+            msg = "llama-cpp-python not available"
+            raise RuntimeError(msg)
+
+        # Normalise the config name to extract base type
+        # E.g., "Q4_K_L" or "Q4_K_XL" -> "Q4_K_M" (default for Q4_K)
+        # E.g., "Q4_K_M_XXL" -> "Q4_K_M"
+        config_upper = config_name.upper()
+
+        # Direct mapping for exact matches
+        type_mapping = {
+            # Q2 variants (not recommended but supported)
+            "Q2_K": llama_cpp.LLAMA_FTYPE_MOSTLY_Q2_K,
+            "Q2_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q2_K_S,
+            # Q3 K-quants
+            "Q3_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q3_K_S,
+            "Q3_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q3_K_M,
+            # Q4 K-quants (most common)
+            "Q4_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_S,
+            "Q4_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M,
+            # Q5 K-quants
+            "Q5_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_K_S,
+            "Q5_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_K_M,
+            # Q6_K (single variant)
+            "Q6_K": llama_cpp.LLAMA_FTYPE_MOSTLY_Q6_K,
+            # Q8_0 (highest common quantisation)
+            "Q8_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q8_0,
+            # Legacy quantisation formats
+            "Q4_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0,
+            "Q4_1": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_1,
+            "Q5_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_0,
+            "Q5_1": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_1,
+            # IQ (Integer Quantisation) variants - experimental
+            "IQ2_XXS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_XXS,
+            "IQ2_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_XS,
+            "IQ2_S": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_S,
+            "IQ2_M": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_M,
+            "IQ3_XXS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_XXS,
+            "IQ3_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_XS,
+            "IQ3_S": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_S,
+            "IQ3_M": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_M,
+            "IQ4_NL": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ4_NL,
+            "IQ4_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ4_XS,
+            # Higher precision formats
+            "F16": llama_cpp.LLAMA_FTYPE_MOSTLY_F16,
+            "BF16": llama_cpp.LLAMA_FTYPE_MOSTLY_BF16,
+        }
+
+        # Try direct lookup first
+        if config_upper in type_mapping:
+            return type_mapping[config_upper]
+
+        # Handle custom variants using base mapping
+        for prefix, base_type in LlamaCppPythonAPI.VARIANT_BASE_MAPPING.items():
+            if config_upper.startswith(prefix) and config_upper not in type_mapping:
+                return type_mapping[base_type]
+
+        # If not found, raise an informative error
+        supported = sorted(type_mapping.keys())
+        msg = (
+            f"Unsupported quantisation type: {config_name}\n"
+            f"Supported types: {', '.join(supported)}\n"
+            f"Custom variants like Q4_K_L, Q4_K_XL are also supported."
+        )
+        raise ValueError(msg)
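
# Illustrative aside (not part of llama_python.py): how the lookup above resolves
# exact names and custom suffixes, assuming llama-cpp-python is installed and
# exposes the usual LLAMA_FTYPE_MOSTLY_* constants.
get = LlamaCppPythonAPI.get_quantisation_type
assert get("Q4_K_M") == llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M  # direct table hit
assert get("q4_k_l") == llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M  # "Q4_K_" prefix -> Q4_K_M base
assert get("Q6_K_XL") == llama_cpp.LLAMA_FTYPE_MOSTLY_Q6_K  # "Q6_K_" prefix -> Q6_K base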
+
+    @staticmethod
+    def get_tensor_type_value(type_name: str) -> int:
+        """Convert tensor type name to llama_cpp constant.
+
+        Maps string tensor type names to their corresponding llama_cpp integer
+        constants for tensor-specific overrides. Provides the foundation for
+        differentiated quantisation strategies across embedding and output layers.
+
+        Returns:
+            Integer value for the tensor type, or 0 if not found.
+        """
+        if not LLAMA_CPP_AVAILABLE:
+            return 0
+
+        # Build mapping with variant consolidation
+        # All Q3_K variants map to base Q3_K type, same for Q4_K and Q5_K
+        type_mapping = LlamaCppPythonAPI._build_tensor_type_mapping()
+        return type_mapping.get(type_name.upper(), 0)
+
+    @staticmethod
+    def _build_tensor_type_mapping() -> dict[str, int]:
+        """Build tensor type mapping with variant consolidation.
+
+        Returns:
+            Dictionary mapping type names to GGML constants.
+        """
+        if not LLAMA_CPP_AVAILABLE:
+            return {}
+
+        # Base mappings
+        return {
+            # Q2 variants
+            "Q2_K": llama_cpp.GGML_TYPE_Q2_K,
+            # Q3 variants - all map to base Q3_K
+            "Q3_K": llama_cpp.GGML_TYPE_Q3_K,
+            "Q3_K_S": llama_cpp.GGML_TYPE_Q3_K,
+            "Q3_K_M": llama_cpp.GGML_TYPE_Q3_K,
+            "Q3_K_L": llama_cpp.GGML_TYPE_Q3_K,
+            # Q4 variants
+            "Q4_0": llama_cpp.GGML_TYPE_Q4_0,
+            "Q4_1": llama_cpp.GGML_TYPE_Q4_1,
+            "Q4_K": llama_cpp.GGML_TYPE_Q4_K,
+            "Q4_K_S": llama_cpp.GGML_TYPE_Q4_K,
+            "Q4_K_M": llama_cpp.GGML_TYPE_Q4_K,
+            # Q5 variants
+            "Q5_0": llama_cpp.GGML_TYPE_Q5_0,
+            "Q5_1": llama_cpp.GGML_TYPE_Q5_1,
+            "Q5_K": llama_cpp.GGML_TYPE_Q5_K,
+            "Q5_K_S": llama_cpp.GGML_TYPE_Q5_K,
+            "Q5_K_M": llama_cpp.GGML_TYPE_Q5_K,
+            # Q6 variant
+            "Q6_K": llama_cpp.GGML_TYPE_Q6_K,
+            # Q8 variant
+            "Q8_0": llama_cpp.GGML_TYPE_Q8_0,
+            # Higher precision
+            "F16": llama_cpp.GGML_TYPE_F16,
+            "F32": llama_cpp.GGML_TYPE_F32,
+        }
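
# Illustrative aside (not part of llama_python.py): variant consolidation means
# every Q5_K_* spelling resolves to the same GGML tensor type, and unknown
# names fall back to 0 rather than raising.
assert LlamaCppPythonAPI.get_tensor_type_value("Q5_K_M") == llama_cpp.GGML_TYPE_Q5_K
assert LlamaCppPythonAPI.get_tensor_type_value("q5_k_s") == llama_cpp.GGML_TYPE_Q5_K
assert LlamaCppPythonAPI.get_tensor_type_value("NOT_A_TYPE") == 0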
+
+    def quantise_model_flexible(
+        self,
+        input_path: Path,
+        output_path: Path,
+        base_type: str,
+        embedding_type: str | None = None,
+        output_type: str | None = None,
+        imatrix_path: Path | None = None,
+    ) -> bool:
+        """Quantise model with flexible tensor type configuration.
+
+        Provides control over base quantisation type with optional overrides for
+        embeddings and output layers, which are the only tensor-specific controls
+        that work reliably with llama-cpp-python.
+
+        Args:
+            input_path: Path to input GGUF model.
+            output_path: Path for output quantised model.
+            base_type: Base quantisation type (e.g., "Q4_K_M", "Q6_K").
+            embedding_type: Override for token embeddings (None = use base).
+            output_type: Override for output/lm_head layers (None = use base).
+            imatrix_path: Optional importance matrix file.
+
+        Returns:
+            True if quantisation successful, False otherwise.
+
+        Examples:
+            # Q4_K_L: Q4_K_M base with Q8_0 embeddings
+            api.quantise_model_flexible(
+                input_path, output_path, "Q4_K_M",
+                embedding_type="Q8_0"
+            )
+
+            # Q3_K_L: Q3_K_M base with Q5_K output
+            api.quantise_model_flexible(
+                input_path, output_path, "Q3_K_M",
+                output_type="Q5_K"
+            )
+
+            # Q3_K_XL: Q3_K_M with both Q8_0 embeddings and Q5_K output
+            api.quantise_model_flexible(
+                input_path, output_path, "Q3_K_M",
+                embedding_type="Q8_0",
+                output_type="Q5_K"
+            )
+
+        Raises:
+            RuntimeError: If llama-cpp-python is not available.
+        """
+        if not LLAMA_CPP_AVAILABLE:
+            msg = "llama-cpp-python not available for quantisation"
+            raise RuntimeError(msg)
+
+        logger.info(f"🔄 Flexible quantisation: {base_type} base")
+        logger.info(f"📝 Input: {input_path}")
+        logger.info(f"📝 Output: {output_path}")
+
+        # Setup phase - create and configure parameters
+        params = self._create_params(base_type, imatrix_path)
+        self._apply_tensor_overrides(params, embedding_type, output_type)
+
+        # Execution phase - perform quantisation
+        try:
+            logger.debug("DEBUG: Starting flexible quantisation execution")
+            result = self._do_quantisation(input_path, output_path, params)
+            logger.debug(f"DEBUG: Flexible quantisation returned: {result}")
+        except Exception as e:
+            logger.error(f"❌ Flexible quantisation failed with exception: {e}")
+            logger.error("Flexible quantisation traceback:")
+            for line in traceback.format_exc().splitlines():
+                logger.error(f" {line}")
+            return False
+        else:
+            if result == 0:
+                # Verify output file was created and is valid
+                if not output_path.exists():
+                    logger.error(
+                        f"❌ Quantisation claimed success but output does not exist: {output_path}"
+                    )
+                    return False
+
+                try:
+                    output_size = output_path.stat().st_size
+                    logger.debug(f"DEBUG: Output file size: {output_size / (1024**3):.2f} GB")
+
+                    if output_size == 0:
+                        logger.error("❌ Output file is empty despite success code")
+                        return False
+                except Exception as e:
+                    logger.warning(f"⚠️ Could not check output file size: {e}")
+
+                logger.info(f"✅ Quantisation successful: {output_path.name}")
+                return True
+
+            logger.error(f"❌ Quantisation failed with code: {result}")
+            return False
+
+    def _create_params(
+        self, base_type: str, imatrix_path: Path | None
+    ) -> llama_model_quantize_params:
+        """Create quantisation parameters.
+
+        Returns:
+            Configured quantisation parameters.
+        """
+        params = llama_model_quantize_params()
+        params.ftype = self.get_quantisation_type(base_type)
+        params.nthread = 8
+        params.allow_requantize = True
+
+        if imatrix_path and imatrix_path.exists():
+            # Convert path to bytes and create c_char_p, then cast to c_void_p
+            imatrix_bytes = str(imatrix_path).encode("utf-8")
+            char_p = ctypes.c_char_p(imatrix_bytes)
+            params.imatrix = ctypes.cast(char_p, ctypes.c_void_p)
+            logger.info(f"🧮 Using imatrix: {imatrix_path.name}")
+
+        return params
+
+    def _apply_tensor_overrides(
+        self,
+        params: llama_model_quantize_params,
+        embedding_type: str | None,
+        output_type: str | None,
+    ) -> None:
+        """Apply embedding and output tensor type overrides to params.
+
+        These are the only tensor-specific controls that work reliably
+        with llama-cpp-python.
+        """
+        # Apply embedding override if specified
+        if embedding_type:
+            params.token_embedding_type = self.get_tensor_type_value(embedding_type)
+            logger.info(f"⚙️ Token embedding type: {embedding_type}")
+
+        # Apply output override if specified
+        if output_type:
+            params.output_tensor_type = self.get_tensor_type_value(output_type)
+            params.quantize_output_tensor = True
+            logger.info(f"⚙️ Output tensor type: {output_type}")
+
+    def _do_quantisation(
+        self,
+        input_path: Path,
+        output_path: Path,
+        params: llama_model_quantize_params,
+    ) -> int:
+        """Perform the quantisation operation.
+
+        Returns:
+            Return code (0 for success).
+
+        Raises:
+            KeyboardInterrupt: If the user interrupts the quantisation process.
+            SystemExit: If the system exits during quantisation.
+        """
+        logger.debug("DEBUG: Calling llama_cpp.llama_model_quantize")
+        try:
+            # Flush any pending output before calling C library
+            sys.stdout.flush()
+            sys.stderr.flush()
+
+            # Temporarily redirect stderr to prevent terminal control issues
+            # Some GGUF models output control sequences that can break the terminal
+            old_stderr_fd = None
+            devnull_fd = None
+
+            try:
+                # Only redirect if not in debug mode to preserve error messages
+                if not logger.isEnabledFor(logging.DEBUG):
+                    old_stderr_fd = os.dup(2)  # Save current stderr
+                    devnull_fd = os.open(os.devnull, os.O_WRONLY)
+                    os.dup2(devnull_fd, 2)  # Redirect stderr to /dev/null
+
+                # Call the quantization with proper exception handling
+                result = llama_cpp.llama_model_quantize(
+                    str(input_path).encode("utf-8"), str(output_path).encode("utf-8"), params
+                )
+
+            finally:
+                # Restore stderr if we redirected it
+                if old_stderr_fd is not None:
+                    os.dup2(old_stderr_fd, 2)
+                    os.close(old_stderr_fd)
+                if devnull_fd is not None:
+                    os.close(devnull_fd)
+
+            # Flush output after the call
+            sys.stdout.flush()
+            sys.stderr.flush()
+        except KeyboardInterrupt:
+            logger.error("❌ Quantisation interrupted by user")
+            raise
+        except SystemExit as e:
+            logger.error(f"❌ System exit during quantisation: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"❌ llama_model_quantize call failed: {e}")
+            logger.error("llama_model_quantize call traceback:")
+            for line in traceback.format_exc().splitlines():
+                logger.error(f" {line}")
+            raise
+        else:
+            logger.debug(f"DEBUG: llama_model_quantize completed with code: {result}")
+            return result
+
+    def quantise_model(
+        self,
+        input_path: Path,
+        output_path: Path,
+        config: QuantisationConfig,
+        imatrix_path: Path | None = None,
+    ) -> bool:
+        """Quantise model using Python API.
+
+        Performs quantisation using llama-cpp-python's direct API access with
+        support for embedding and output tensor type overrides. The L and XL
+        variants use a base type with specific overrides.
+
+        Returns:
+            True if quantisation successful, False otherwise.
+
+        Raises:
+            RuntimeError: If llama-cpp-python is not available.
+        """
+        if not LLAMA_CPP_AVAILABLE:
+            msg = "llama-cpp-python not available for quantisation"
+            raise RuntimeError(msg)
+
+        # Force cleanup before starting
+        gc.collect()
+
+        # Log initial resource state
+        mem_before = self._log_resource_state("before")
+
+        try:
+            # Validate input
+            if not self._validate_input_file(input_path):
+                return False
+            # Setup parameters
+            params = self._setup_quantisation_params(config, imatrix_path)
+            if params is None:
+                return False
+            # Execute quantisation
+            result = self._execute_quantisation(input_path, output_path, params)
+            # Verify and finalize
+            if result == 0:
+                return self._finalize_successful_quantisation(output_path, mem_before)
+
+            logger.error(f"❌ Quantisation failed with code: {result}")
+        except Exception as e:
+            logger.error(f"❌ Quantisation failed with exception: {e}")
+            logger.error("Full quantisation traceback:")
+            for line in traceback.format_exc().splitlines():
+                logger.error(f" {line}")
+        # Garbage collect and return false
+        gc.collect()
+        return False
+
+    def _log_resource_state(self, phase: str) -> float:
+        """Log current resource usage state.
+
+        Args:
+            phase: Description of current phase (e.g., "before", "after").
+
+        Returns:
+            Current memory usage in GB.
+        """
+        process = psutil.Process()
+        memory_gb = process.memory_info().rss / (1024**3)
+        logger.debug(f"DEBUG: Memory {phase} quantisation: {memory_gb:.2f} GB")
+        logger.debug(f"DEBUG: Open file descriptors: {len(process.open_files())}")
+        if phase == "before":
+            logger.debug(f"DEBUG: Process PID: {process.pid}")
+        return memory_gb
+
+    def _validate_input_file(self, input_path: Path) -> bool:
+        """Validate input file exists and is readable.
+
+        Args:
+            input_path: Path to input file.
+
+        Returns:
+            True if file is valid, False otherwise.
+        """
+        logger.debug(f"DEBUG: Starting quantisation of {input_path.name}")
+        logger.info(f"🔄 Quantising {input_path.name}...")
+        logger.debug(f"DEBUG: Input: {input_path}")
+
+        if not input_path.exists():
+            logger.error(f"❌ Input file does not exist: {input_path}")
+            return False
+
+        if not input_path.is_file():
+            logger.error(f"❌ Input path is not a file: {input_path}")
+            return False
+
+        try:
+            input_size = input_path.stat().st_size
+            logger.debug(f"DEBUG: Input file size: {input_size / (1024**3):.2f} GB")
+            if input_size == 0:
+                logger.error("❌ Input file is empty")
+                return False
+        except Exception as e:
+            logger.warning(f"⚠️ Could not check input file size: {e}")
+
+        return True
+
+    def _setup_quantisation_params(
+        self,
+        config: QuantisationConfig,
+        imatrix_path: Path | None,
+    ) -> llama_model_quantize_params | None:
+        """Setup quantisation parameters.
+
+        Args:
+            config: Quantisation configuration.
+            imatrix_path: Optional path to importance matrix.
+
+        Returns:
+            Configured parameters or None if setup failed.
+        """
+        logger.debug("DEBUG: Setting up quantisation parameters")
+        params = llama_model_quantize_params()
+
+        # Set base quantisation type
+        try:
+            params.ftype = self.get_quantisation_type(config.base_type)
+            logger.debug(
+                f"DEBUG: Set ftype to {params.ftype} for {config.base_type} (config: {config.name})"
+            )
+        except Exception as e:
+            logger.error(f"❌ Failed to get quantisation type for {config.name}: {e}")
+            return None
+
+        # Configure basic parameters
+        params.nthread = 8
+        params.allow_requantize = True
+        logger.debug(
+            f"DEBUG: Set nthread={params.nthread}, allow_requantize={params.allow_requantize}"
+        )
+
+        # Add imatrix if available
+        if imatrix_path and imatrix_path.exists():
+            try:
+                # Convert path to bytes and create c_char_p, then cast to c_void_p
+                imatrix_bytes = str(imatrix_path).encode("utf-8")
+                char_p = ctypes.c_char_p(imatrix_bytes)
+                params.imatrix = ctypes.cast(char_p, ctypes.c_void_p)
+                logger.info(f"🧮 Using imatrix: {imatrix_path.name}")
+                logger.debug(f"DEBUG: imatrix path set: {imatrix_path}")
+            except Exception as e:
+                logger.error(f"❌ Failed to set imatrix: {e}")
+                # Continue without imatrix
+
+        # Configure tensor-specific types
+        logger.debug("DEBUG: Configuring tensor-specific types")
+        try:
+            self._configure_tensor_types(params, config)
+            logger.debug("DEBUG: Tensor types configured successfully")
+        except Exception as e:
+            logger.error(f"❌ Failed to configure tensor types: {e}")
+            logger.error("Tensor type configuration traceback:")
+            for line in traceback.format_exc().splitlines():
+                logger.error(f" {line}")
+            # Continue with default types
+
+        return params
+
+    def _execute_quantisation(
+        self,
+        input_path: Path,
+        output_path: Path,
+        params: llama_model_quantize_params,
+    ) -> int:
+        """Execute the actual quantisation with signal handling.
+
+        Args:
+            input_path: Path to input model.
+            output_path: Path for output model.
+            params: Configured quantisation parameters.
+
+        Returns:
+            Return code from quantisation (0 for success).
+        """
+        logger.debug("DEBUG: Starting llama_cpp.llama_model_quantize call")
+        logger.debug("DEBUG: About to call llama_model_quantize...")
+
+        # Setup signal handlers
+        old_handlers = self._setup_signal_handlers()
+
+        try:
+            result = llama_cpp.llama_model_quantize(
+                str(input_path).encode("utf-8"), str(output_path).encode("utf-8"), params
+            )
+            logger.debug(f"DEBUG: llama_model_quantize returned: {result}")
+        except Exception as e:
+            logger.error(f"❌ llama_model_quantize raised exception: {e}")
+            logger.error("llama_model_quantize traceback:")
+            for line in traceback.format_exc().splitlines():
+                logger.error(f" {line}")
+            return -1
+        else:
+            return result
+        finally:
+            self._restore_signal_handlers(old_handlers)
+
+    def _setup_signal_handlers(self) -> tuple[Any, Any | None]:
+        """Setup signal handlers for debugging termination.
+
+        Returns:
+            Tuple of (old_sigterm, old_sigsegv) handlers.
+        """
+
+        def signal_debug_handler(signum: int, frame: object) -> Never:  # noqa: ARG001
+            logger.error(f"DEBUG: Received signal {signum} during quantisation!")
+            logger.error(f"DEBUG: Signal name: {signal.Signals(signum).name}")
+            msg = f"Signal {signum} received"
+            raise KeyboardInterrupt(msg)
+
+        old_sigterm = signal.signal(signal.SIGTERM, signal_debug_handler)
+        old_sigsegv = (
+            signal.signal(signal.SIGSEGV, signal_debug_handler)
+            if hasattr(signal, "SIGSEGV")
+            else None
+        )
+        return old_sigterm, old_sigsegv
+
+    def _restore_signal_handlers(self, handlers: tuple[Any, Any | None]) -> None:
+        """Restore original signal handlers.
+
+        Args:
+            handlers: Tuple of (old_sigterm, old_sigsegv) handlers.
+        """
+        old_sigterm, old_sigsegv = handlers
+        signal.signal(signal.SIGTERM, old_sigterm)
+        if old_sigsegv is not None:
+            signal.signal(signal.SIGSEGV, old_sigsegv)
def _finalize_successful_quantisation(
|
||||
self,
|
||||
output_path: Path,
|
||||
mem_before: float,
|
||||
) -> bool:
|
||||
"""Finalize successful quantisation and verify output.
|
||||
|
||||
Args:
|
||||
output_path: Path to output file.
|
||||
mem_before: Memory usage before quantisation in GB.
|
||||
|
||||
Returns:
|
||||
True if output is valid, False otherwise.
|
||||
"""
|
||||
logger.debug("DEBUG: Quantisation returned success code")
|
||||
|
||||
# Verify output exists
|
||||
if not output_path.exists():
|
||||
logger.error(
|
||||
f"❌ Quantisation claimed success but output does not exist: {output_path}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Verify output size
|
||||
output_size = output_path.stat().st_size
|
||||
logger.debug(f"DEBUG: Output file size: {output_size / (1024**3):.2f} GB")
|
||||
|
||||
if output_size == 0:
|
||||
logger.error("❌ Output file is empty despite success code")
|
||||
return False
|
||||
|
||||
logger.info(f"✅ Quantisation successful: {output_path.name}")
|
||||
|
||||
# Force cleanup and log final state
|
||||
gc.collect()
|
||||
mem_after = self._log_resource_state("after")
|
||||
logger.debug(f"DEBUG: Memory delta: {mem_after - mem_before:+.2f} GB")
|
||||
|
||||
return True
|
||||

    def _configure_tensor_types(
        self, params: llama_model_quantize_params, config: QuantisationConfig
    ) -> None:
        """Configure tensor-specific quantisation types.

        Sets embedding and output tensor type overrides based on config.
        These are the only tensor-specific controls that work reliably
        with llama-cpp-python.
        """
        logger.debug(f"DEBUG: _configure_tensor_types called for {config.name}")

        # Apply embedding override if specified
        if config.embedding_type:
            params.token_embedding_type = self.get_tensor_type_value(config.embedding_type)
            logger.info(f"⚙️ Token embedding type: {config.embedding_type}")

        # Apply output override if specified
        if config.output_type:
            params.output_tensor_type = self.get_tensor_type_value(config.output_type)
            params.quantize_output_tensor = True
            logger.info(f"⚙️ Output tensor type: {config.output_type}")
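
Concretely, the two overrides map onto fields of llama_model_quantize_params; an illustrative sketch assuming llama-cpp-python's low-level bindings, with Q6_K chosen purely as an example:

import llama_cpp

params = llama_cpp.llama_model_quantize_default_params()
params.ftype = llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M
# Keep token embeddings at a higher precision than the model body
params.token_embedding_type = llama_cpp.GGML_TYPE_Q6_K
# Override the output tensor and force it to be quantised
params.output_tensor_type = llama_cpp.GGML_TYPE_Q6_K
params.quantize_output_tensor = True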

    def convert_hf_to_gguf(
        self, input_dir: Path, output_path: Path, output_type: str = "f16"
    ) -> bool:
        """Convert HuggingFace model to GGUF format using native Python converter.

        Uses our GGUFConverter for SafeTensors models, providing full Python-based
        conversion without external dependencies.

        Returns:
            True if conversion successful, False otherwise.
        """
        logger.info(f"🔄 Converting {input_dir.name} to GGUF format...")
        logger.info(f"📝 Input: {input_dir}")
        logger.info(f"📝 Output: {output_path}")
        logger.info(f"📝 Type: {output_type}")

        # Check for SafeTensors files
        safetensor_files = list(input_dir.glob("*.safetensors"))
        if not safetensor_files:
            logger.warning("⚠️ No SafeTensors files found in model directory")
            return False

        try:
            # Load model configuration
            config_parser = ConfigParser()
            model_config = config_parser.load_model_config(input_dir)

            # Get architecture mapping
            arch_name = model_config.architectures[0] if model_config.architectures else "llama"
            arch = config_parser.get_architecture_mapping(arch_name)

            if arch != arch_name:
                logger.info(f"📝 Architecture mapping: {arch_name} → {arch}")

            # Convert using GGUFConverter
            tensor_mapper = TensorMapper()
            success = GGUFConverter.convert_safetensors(
                input_dir, output_path, model_config, arch, tensor_mapper
            )
        except Exception as e:
            logger.error(f"❌ Conversion failed with exception: {e}")
            return False
        else:
            if success:
                logger.info("✅ Native Python conversion successful")
            return success
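
A hypothetical call site for the converter, with paths chosen for illustration; the method is assumed to live on the LlamaCppPythonAPI class that the rest of this diff modifies:

from pathlib import Path

api = LlamaCppPythonAPI()
ok = api.convert_hf_to_gguf(
    input_dir=Path("models/Example-7B"),  # downloaded SafeTensors checkpoint
    output_path=Path("models/Example-7B/Example-7B-f16.gguf"),
    output_type="f16",
)
if not ok:
    raise RuntimeError("conversion failed; see log output for details")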

@@ -7,12 +7,22 @@ status tracking, and cleanup operations for efficient resource utilisation.

from __future__ import annotations

from concurrent.futures import Future, ThreadPoolExecutor
import gc
import signal
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING

from helpers.config.quantisation_configs import QUANTISATION_CONFIGS, SUPPORTED_QUANTISATION_TYPES
import psutil

from helpers.config.quantisation_configs import (
    DEFAULT_QUANTISATION_TYPES,
    QUANTISATION_CONFIGS,
    SUPPORTED_QUANTISATION_TYPES,
)
from helpers.logger import logger
from helpers.models.quantisation import (
    ModelSource,

@@ -21,10 +31,13 @@ from helpers.models.quantisation import (
    QuantisationType,
)
from helpers.services.huggingface import ReadmeGenerator
from helpers.services.llama_cpp import EnvironmentManager, IMatrixGenerator
from helpers.services.llama_cpp import IMatrixManager
from helpers.services.quantisation import HuggingFaceUploader, ModelManager, QuantisationEngine
from helpers.utils.tensor_mapping import URLParser

if TYPE_CHECKING:
    from types import FrameType


@dataclass(slots=True)
class QuantisationOrchestrator:

@@ -36,73 +49,134 @@ class QuantisationOrchestrator:

    work_dir: Path = field(default_factory=lambda: Path.cwd() / "quantisation_work")
    use_imatrix: bool = True
    imatrix_base: str = "Q4_K_M"
    no_upload: bool = False
    custom_profiles: list[str] | None = None

    # Service dependencies with factory defaults
    url_parser: URLParser = field(default_factory=URLParser)
    quantisation_engine: QuantisationEngine = field(default_factory=QuantisationEngine)
    imatrix_generator: IMatrixGenerator = field(default_factory=IMatrixGenerator)
    imatrix_manager: IMatrixManager = field(default_factory=IMatrixManager)
    readme_generator: ReadmeGenerator = field(default_factory=ReadmeGenerator)
    uploader: HuggingFaceUploader = field(default_factory=HuggingFaceUploader)

    # Computed properties
    models_dir: Path = field(init=False)
    environment_manager: EnvironmentManager = field(init=False)
    model_manager: ModelManager = field(init=False)

    def __post_init__(self) -> None:
        """Initialise computed properties after dataclass construction."""
        self.models_dir = self.work_dir / "models"
        self.environment_manager = EnvironmentManager(self.work_dir)
        self.model_manager = ModelManager(self.models_dir, self.environment_manager)
        self.model_manager = ModelManager(self.models_dir)

        # Set up signal handlers for graceful exit tracking
        self._setup_signal_handlers()

    def _setup_signal_handlers(self) -> None:
        """Set up signal handlers to catch unexpected exits."""

        def signal_handler(signum: int, frame: FrameType | None) -> None:
            logger.error(f"❌ Received signal {signum} ({signal.Signals(signum).name})")
            logger.error("Stack trace at signal:")
            if frame:
                for line in traceback.format_stack(frame):
                    logger.error(f"  {line.strip()}")
            logger.error("Exiting due to signal")
            sys.exit(1)

        # Handle common termination signals
        for sig in [signal.SIGINT, signal.SIGTERM]:
            signal.signal(sig, signal_handler)

    def get_quantisation_types(self) -> list[QuantisationType]:
        """Get the quantisation types to use for this run.

        Returns:
            List of QuantisationType enums to process.
        """
        if self.custom_profiles:
            # Parse custom profiles from strings to QuantisationType
            result = []
            for profile_str in self.custom_profiles:
                try:
                    profile = QuantisationType(profile_str.upper())
                    if profile in SUPPORTED_QUANTISATION_TYPES:
                        result.append(profile)
                    else:
                        logger.warning(f"Profile {profile_str} is not supported, skipping")
                except ValueError:
                    logger.warning(f"Invalid profile {profile_str}, skipping")
            return result or DEFAULT_QUANTISATION_TYPES
        return DEFAULT_QUANTISATION_TYPES
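
An illustrative round trip through the parsing above; the profile strings are placeholders. Each one is upper-cased before the QuantisationType lookup, so input is effectively case-insensitive, and anything unknown or unsupported is logged and dropped:

orchestrator = QuantisationOrchestrator(custom_profiles=["q4_k_m", "q6_k", "bogus"])
types = orchestrator.get_quantisation_types()
# "q4_k_m" and "q6_k" become QuantisationType members (assuming both are in
# SUPPORTED_QUANTISATION_TYPES); "bogus" triggers a warning and is skipped.
# An empty result falls back to DEFAULT_QUANTISATION_TYPES.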

    def quantise(self, url: str) -> dict[QuantisationType, QuantisationResult]:
        """Main quantisation workflow orchestrating model processing from URL to upload.

        Returns:
            dict[QuantisationType, QuantisationResult]: Quantisation results for each type.

        Raises:
            KeyboardInterrupt: If the user interrupts the quantisation process.
        """
        logger.info("Starting Bartowski quantisation process...")
        logger.debug(f"DEBUG: Input URL: {url}")
        logger.debug(f"DEBUG: Working directory: {self.work_dir}")
        logger.debug(f"DEBUG: Use imatrix: {self.use_imatrix}")
        logger.debug(f"DEBUG: No upload: {self.no_upload}")
        logger.debug(f"DEBUG: Custom profiles: {self.custom_profiles}")

        # Setup and preparation
        model_source, llama_env, f16_model_path, imatrix_path, output_repo = (
            self._setup_environment(url)
        )
        try:
            # Setup and preparation
            logger.debug("DEBUG: Starting environment setup...")
            model_source, f16_model_path, imatrix_path, output_repo = self._setup_environment(url)
            logger.debug(f"DEBUG: Environment setup complete. F16 model: {f16_model_path}")

        # Create initial repository
        self._create_initial_repository(model_source, output_repo)
            # Create initial repository
            logger.debug("DEBUG: Creating initial repository...")
            self._create_initial_repository(model_source, output_repo)
            logger.debug("DEBUG: Initial repository created")

        # Execute all quantisations
        results = self._execute_quantisations(
            model_source, llama_env, f16_model_path, imatrix_path, output_repo
        )
            # Execute all quantisations
            logger.debug("DEBUG: Starting quantisation execution...")
            results = self._execute_quantisations(
                model_source, f16_model_path, imatrix_path, output_repo
            )
            logger.debug(f"DEBUG: Quantisation execution complete. Results: {len(results)} items")

        # Cleanup
        self._cleanup_files(f16_model_path, model_source)
            # Cleanup
            logger.debug("DEBUG: Starting cleanup...")
            self._cleanup_files(f16_model_path, model_source)
            logger.debug("DEBUG: Cleanup complete")

        self._print_completion_summary(model_source, results, output_repo)
        return results
            self._print_completion_summary(model_source, results, output_repo)
        except KeyboardInterrupt:
            logger.error("❌ Process interrupted by user (Ctrl+C)")
            raise
        except Exception as e:
            logger.error(f"❌ Critical error in quantisation workflow: {e}")
            logger.error("Full traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            raise
        else:
            return results

    def _setup_environment(self, url: str) -> tuple[ModelSource, Any, Path, Path | None, str]:
    def _setup_environment(self, url: str) -> tuple[ModelSource, Path, Path | None, str]:
        """Setup environment and prepare model for quantisation.

        Returns:
            Tuple of (model_source, llama_env, f16_model_path, imatrix_path, output_repo).
            Tuple of (model_source, f16_model_path, imatrix_path, output_repo).
        """
        model_source = self.url_parser.parse(url)
        self._print_model_info(model_source)

        self.models_dir.mkdir(parents=True, exist_ok=True)
        llama_env = self.environment_manager.setup()

        f16_model_path = self.model_manager.prepare_model(model_source, llama_env)
        f16_model_path = self.model_manager.prepare_model(model_source)

        imatrix_path = None
        if self.use_imatrix:
            logger.info("Generating importance matrix (imatrix)...")
            imatrix_path = self.imatrix_generator.generate_imatrix(
                f16_model_path, llama_env, self.models_dir / model_source.model_name
            logger.info("Checking for importance matrix (imatrix)...")
            imatrix_path = self.imatrix_manager.find_imatrix(
                self.models_dir / model_source.model_name
            )

        output_repo = (

@@ -110,14 +184,15 @@ class QuantisationOrchestrator:
            f"{model_source.original_author}-{model_source.model_name}-GGUF"
        )

        return model_source, llama_env, f16_model_path, imatrix_path, output_repo
        return model_source, f16_model_path, imatrix_path, output_repo

    def _create_initial_repository(self, model_source: ModelSource, output_repo: str) -> None:
        """Create initial repository with planned quantisations."""
        logger.info("Creating initial README with planned quantisations...")
        quantisation_types = self.get_quantisation_types()
        planned_results = {
            qt: QuantisationResult(quantisation_type=qt, success=False, status="planned")
            for qt in SUPPORTED_QUANTISATION_TYPES
            for qt in quantisation_types
        }
        readme_path = self.readme_generator.generate(
            model_source, planned_results, self.models_dir, output_repo

@@ -132,7 +207,6 @@
    def _execute_quantisations(
        self,
        model_source: ModelSource,
        llama_env: Any,
        f16_model_path: Path,
        imatrix_path: Path | None,
        output_repo: str,

@@ -143,23 +217,56 @@
            dict[QuantisationType, QuantisationResult]: Quantisation results for each type.
        """
        results: dict[QuantisationType, QuantisationResult] = {}
        upload_futures: list[Future[None]] = []

        with ThreadPoolExecutor(max_workers=1, thread_name_prefix="uploader") as upload_executor:
            for quant_type in SUPPORTED_QUANTISATION_TYPES:
                result = self._process_single_quantisation(
                    quant_type,
                    model_source,
                    llama_env,
                    f16_model_path,
                    imatrix_path,
                    output_repo,
                    results,
                    upload_executor,
                    upload_futures,
        quantisation_types = self.get_quantisation_types()
        types_list = [qt.value for qt in quantisation_types]
        logger.info(f"Processing {len(quantisation_types)} quantisation types: {types_list}")

        # Process with parallel uploads - quantise sequentially but upload in background
        upload_futures = []
        with ThreadPoolExecutor(max_workers=2, thread_name_prefix="upload") as upload_executor:
            for i, quant_type in enumerate(quantisation_types, 1):
                logger.info(
                    f"Processing quantisation {i}/{len(quantisation_types)}: {quant_type.value}"
                )
                results[quant_type] = result
                logger.debug(f"DEBUG: Starting quantisation {i}/{len(quantisation_types)}")
                logger.debug(f"DEBUG: Current type: {quant_type.value}")
                logger.debug(f"DEBUG: Results so far: {len(results)} completed")

                try:
                    result = self._process_single_quantisation(
                        quant_type,
                        model_source,
                        f16_model_path,
                        imatrix_path,
                        output_repo,
                        results,
                        upload_executor,
                        upload_futures,
                    )
                    results[quant_type] = result
                    logger.debug(f"DEBUG: Quantisation {quant_type.value} completed")

                    # Force cleanup between quantisations
                    gc.collect()
                    logger.debug("DEBUG: Garbage collection completed")

                except Exception as e:
                    logger.error(f"❌ Critical error processing {quant_type.value}: {e}")
                    logger.error("Exception traceback:")
                    for line in traceback.format_exc().splitlines():
                        logger.error(f"  {line}")
                    results[quant_type] = QuantisationResult(
                        quantisation_type=quant_type,
                        success=False,
                        status="failed",
                        error_message=str(e),
                    )

                    # Force cleanup after error
                    gc.collect()

            # Wait for all uploads to complete before returning
            self._wait_for_uploads(upload_futures)

        return results
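
The shape of that loop, reduced to its essentials: quantisation is CPU- and memory-bound and runs one type at a time, while uploads are network-bound and overlap with the next quantisation via a small executor. A stdlib-only sketch with placeholder work functions:

from concurrent.futures import Future, ThreadPoolExecutor

def quantise(job: str) -> str:  # placeholder for the real engine call
    return f"{job}.gguf"

def upload(artefact: str) -> None:  # placeholder for the uploader
    print(f"uploaded {artefact}")

def pipeline(jobs: list[str]) -> None:
    futures: list[Future[None]] = []
    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="upload") as pool:
        for job in jobs:
            artefact = quantise(job)  # runs in the main thread
            futures.append(pool.submit(upload, artefact))  # overlaps the next job
        for future in futures:
            future.result()  # surface upload errors before returning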

@@ -168,7 +275,6 @@
        self,
        quant_type: QuantisationType,
        model_source: ModelSource,
        llama_env: Any,
        f16_model_path: Path,
        imatrix_path: Path | None,
        output_repo: str,

@@ -183,26 +289,33 @@
        """
        try:
            logger.info(f"Starting {quant_type.value} quantisation...")
            logger.debug(f"DEBUG: Getting config for {quant_type.value}")
            config = QUANTISATION_CONFIGS[quant_type]
            logger.debug(f"DEBUG: Config loaded: {config.name}")

            # Update status to processing
            logger.debug("DEBUG: Creating initial quantisation result")
            result = QuantisationResult(quantisation_type=quant_type, success=False)
            result.status = "processing"
            results[quant_type] = result

            logger.debug("DEBUG: Updating README status")
            self._update_readme_status(model_source, results, output_repo)

            # Perform quantisation
            logger.debug("DEBUG: Creating quantisation context")
            context = QuantisationContext(
                f16_model_path=f16_model_path,
                model_source=model_source,
                config=config,
                llama_env=llama_env,
                models_dir=self.models_dir,
                imatrix_path=imatrix_path,
                base_quant=self.imatrix_base,
            )
            logger.debug(f"DEBUG: Context created. F16 path: {f16_model_path}")
            logger.debug(f"DEBUG: imatrix path: {imatrix_path}")
            logger.debug("DEBUG: Calling quantisation engine...")
            result = self.quantisation_engine.quantise(context)
            logger.debug(f"DEBUG: Quantisation engine returned: success={result.success}")

            self._handle_quantisation_result(
                result,

@@ -220,6 +333,108 @@
        else:
            return result

    def _process_single_quantisation_sequential(
        self,
        quant_type: QuantisationType,
        model_source: ModelSource,
        f16_model_path: Path,
        imatrix_path: Path | None,
        output_repo: str,
        results: dict[QuantisationType, QuantisationResult],
    ) -> QuantisationResult:
        """Process a single quantisation type sequentially with immediate upload.

        Returns:
            QuantisationResult: Result of the quantisation attempt.
        """
        # Force cleanup before starting new quantisation
        gc.collect()

        # Log system state before quantisation
        process = psutil.Process()
        logger.debug(f"DEBUG: === System state before {quant_type.value} ===")
        logger.debug(f"DEBUG: Process alive: {process.is_running()}")
        logger.debug(f"DEBUG: PID: {process.pid}")
        logger.debug(f"DEBUG: Memory: {process.memory_info().rss / (1024**3):.2f} GB")
        logger.debug(f"DEBUG: CPU percent: {process.cpu_percent()}%")
        logger.debug(f"DEBUG: Threads: {process.num_threads()}")
        logger.debug(f"DEBUG: Open files: {len(process.open_files())}")

        try:
            logger.info(f"Starting {quant_type.value} quantisation...")
            logger.debug(f"DEBUG: Getting config for {quant_type.value}")
            config = QUANTISATION_CONFIGS[quant_type]
            logger.debug(f"DEBUG: Config loaded: {config.name}")

            # Update status to processing
            logger.debug("DEBUG: Creating initial quantisation result")
            result = QuantisationResult(quantisation_type=quant_type, success=False)
            result.status = "processing"
            results[quant_type] = result

            logger.debug("DEBUG: Updating README status")
            self._update_readme_status(model_source, results, output_repo)

            # Perform quantisation
            logger.debug("DEBUG: Creating quantisation context")
            context = QuantisationContext(
                f16_model_path=f16_model_path,
                model_source=model_source,
                config=config,
                models_dir=self.models_dir,
                imatrix_path=imatrix_path,
            )
            logger.debug(f"DEBUG: Context created. F16 path: {f16_model_path}")
            logger.debug(f"DEBUG: imatrix path: {imatrix_path}")
            logger.debug("DEBUG: Calling quantisation engine...")
            result = self.quantisation_engine.quantise(context)
            logger.debug(f"DEBUG: Quantisation engine returned: success={result.success}")

            if result.success and result.file_path:
                # Upload immediately (if not in no-upload mode)
                if not self.no_upload:
                    logger.info(f"Uploading {quant_type.value}...")
                    try:
                        self.uploader.upload_model_file(output_repo, result.file_path)
                        logger.info(f"Upload of {quant_type.value} completed successfully")

                        # Clean up file after successful upload
                        logger.info(f"Removing {result.file_path.name} to save disk space...")
                        result.file_path.unlink()

                        result.status = "completed"
                        self._update_readme_status(model_source, results, output_repo)
                    except Exception as upload_error:
                        logger.error(f"Failed to upload {quant_type.value}: {upload_error}")
                        result.status = "failed"
                        result.error_message = str(upload_error)
                        self._update_readme_status(model_source, results, output_repo)
                        # Keep file if upload failed
                else:
                    # No upload mode - just mark as completed
                    result.status = "completed"
                    logger.info(f"Skipping upload of {quant_type.value} (--no-upload specified)")
            else:
                result.status = "failed"
                self._update_readme_status(model_source, results, output_repo)
        except Exception as e:
            logger.error(f"Error processing {quant_type.value}: {e}")
            result = QuantisationResult(quantisation_type=quant_type, success=False)
            result.status = "failed"
            result.error_message = str(e)

            try:
                self._update_readme_status(model_source, results, output_repo)
            except Exception as readme_error:
                logger.error(f"Failed to update README after error: {readme_error}")
            # Force cleanup after error
            gc.collect()
            return result
        else:
            # Force cleanup after quantisation
            gc.collect()
            return result
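
The psutil block at the top of this method could be factored into a helper if it grows further; a sketch using only documented psutil accessors:

import psutil

def process_snapshot(tag: str) -> str:
    proc = psutil.Process()  # the current process by default
    rss_gb = proc.memory_info().rss / (1024**3)
    return (
        f"[{tag}] pid={proc.pid} rss={rss_gb:.2f} GB "
        f"threads={proc.num_threads()} open_files={len(proc.open_files())}"
    )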

    def _handle_quantisation_result(
        self,
        result: QuantisationResult,

@@ -328,8 +543,9 @@
    ) -> None:
        """Upload file and clean up (runs in background thread)."""
        try:
            logger.info(f"[PARALLEL] Uploading {quant_type}...")
            logger.info(f"[PARALLEL] Starting upload of {quant_type.value} ({file_path.name})")
            self.uploader.upload_model_file(output_repo, file_path)
            logger.info(f"[PARALLEL] Upload of {quant_type.value} completed successfully")

            logger.info(f"[PARALLEL] Removing {file_path.name} to save disk space...")
            file_path.unlink()

@@ -346,11 +562,16 @@
            results[quant_type].status = "failed"
            results[quant_type].error_message = str(e)

            updated_readme_path = self.readme_generator.generate(
                model_source, results, self.models_dir, output_repo
            )
            self.uploader.upload_readme(output_repo, updated_readme_path)
            raise
            try:
                updated_readme_path = self.readme_generator.generate(
                    model_source, results, self.models_dir, output_repo
                )
                self.uploader.upload_readme(output_repo, updated_readme_path)
            except Exception as readme_error:
                logger.error(
                    f"[PARALLEL] Failed to update README after upload error: {readme_error}"
                )
            # Don't re-raise - let other uploads continue

    def _print_model_info(self, model_source: ModelSource) -> None:
        """Print model information."""

@@ -9,7 +9,9 @@ from __future__ import annotations

import shutil
import subprocess
from typing import TYPE_CHECKING
import tempfile
import traceback
from pathlib import Path

from helpers.logger import logger
from helpers.models.quantisation import (

@@ -19,12 +21,10 @@ from helpers.models.quantisation import (
    QuantisationType,
)
from helpers.services.filesystem import FilesystemService

if TYPE_CHECKING:
    from pathlib import Path

    from helpers.models.quantisation import LlamaCppEnvironment
    from helpers.services.llama_cpp import EnvironmentManager
from helpers.services.gguf import GGUFConverter
from helpers.services.llama_python import LlamaCppPythonAPI
from helpers.utils.config_parser import ConfigParser
from helpers.utils.tensor_mapping import TensorMapper


class QuantisationEngine:

@@ -32,145 +32,88 @@ class QuantisationEngine:

    Provides flexible quantisation execution supporting multiple tensor
    precision configurations, importance matrices, and fallback strategies.
    Encapsulates llama-quantize binary interactions with real-time output.
    Uses llama-cpp-python API for direct quantisation without subprocess overhead.
    """

    def __init__(self) -> None:
        """Initialise quantisation engine."""
        self.fs = FilesystemService()
        self.python_api = LlamaCppPythonAPI()

    def quantise(self, context: QuantisationContext) -> QuantisationResult:
        """Perform quantisation using the specified configuration.

        Executes quantisation with primary and fallback methods, handling
        tensor-specific precision overrides and importance matrix guidance.
        Executes quantisation using Python API. Since llama-cpp-python is a
        required dependency, we can rely on it being available.

        Returns:
            QuantisationResult with success status and file information.
        """
        logger.debug(f"DEBUG: Starting quantisation for {context.config.name}")
        logger.info(
            f"⚙️ Creating {context.config.name} quantisation ({context.config.description})..."
        )

        output_path = context.get_output_path()
        logger.debug(f"DEBUG: Output path: {output_path}")

        logger.info(f"🎯 Attempting {context.config.name} quantisation...")
        logger.info(f"📝 Source: {context.f16_model_path}")
        logger.info(f"📝 Target: {output_path}")

        # Try primary method
        if self._try_quantisation_method(
            context, output_path, context.config.tensor_types, "method 1"
        ):
            return self._create_success_result(context.config.name, output_path, "method 1")

        # Try fallback methods
        for i, fallback_method in enumerate(context.config.fallback_methods, 2):
            method_name = f"method {i}"
            if self._try_quantisation_method(context, output_path, fallback_method, method_name):
                return self._create_success_result(context.config.name, output_path, method_name)

        logger.error("All %s quantisation methods failed", context.config.name)
        return QuantisationResult(
            quantisation_type=QuantisationType(context.config.name),
            success=False,
            error_message="All quantisation methods failed",
        )

    def _try_quantisation_method(
        self,
        context: QuantisationContext,
        output_path: Path,
        tensor_config: dict[str, str],
        method_name: str,
    ) -> bool:
        """Try a specific quantisation method with real-time output.

        Builds and executes llama-quantize command with appropriate parameters,
        streaming output for progress monitoring.

        Returns:
            True if quantisation successful, False otherwise.
        """
        logger.info(f"🔍 Trying {method_name}...")

        cmd = self._build_quantisation_command(context, output_path, tensor_config)
        return self._execute_quantisation_command(cmd, method_name)

    def _build_quantisation_command(
        self, context: QuantisationContext, output_path: Path, tensor_config: dict[str, str]
    ) -> list[str]:
        """Build quantisation command with all required parameters.

        Returns:
            List of command arguments.
        """
        cmd = [str(context.llama_env.quantise_binary)]

        # Add imatrix if available
        if context.imatrix_path and context.imatrix_path.exists():
            cmd.extend(["--imatrix", str(context.imatrix_path)])
            logger.info(f"🧮 Using imatrix: {context.imatrix_path.name}")

        # Add tensor type arguments
        self._add_tensor_type_arguments(cmd, tensor_config)

        cmd.extend([str(context.f16_model_path), str(output_path), context.base_quant])
        return cmd

    def _add_tensor_type_arguments(self, cmd: list[str], tensor_config: dict[str, str]) -> None:
        """Add tensor type arguments to command."""
        if not tensor_config:
            return

        for tensor_name, quant_type in tensor_config.items():
            if tensor_name.startswith(("token-embedding-type", "output-tensor-type")):
                cmd.extend([f"--{tensor_name}", quant_type])
            else:
                cmd.extend(["--tensor-type", f"{tensor_name}={quant_type}"])

    def _execute_quantisation_command(self, cmd: list[str], method_name: str) -> bool:
        """Execute quantisation command with real-time output.

        Returns:
            True if quantisation successful, False otherwise.
        """
        logger.info(f"💻 Running: {' '.join(cmd)}")
        logger.info("⏳ Quantisation in progress... (this may take several minutes)")

        try:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                bufsize=1,
        # Check input file exists and is readable
        if not context.f16_model_path.exists():
            error_msg = f"Input model file does not exist: {context.f16_model_path}"
            logger.error(f"❌ {error_msg}")
            return QuantisationResult(
                quantisation_type=QuantisationType(context.config.name),
                success=False,
                error_message=error_msg,
            )

            self._stream_quantisation_output(process)

            return_code = process.poll()
            if return_code == 0:
                logger.info(f"✅ {method_name} quantisation successful!")
                return True
        # Check if we have enough disk space (rough estimate)
        try:
            input_size = context.f16_model_path.stat().st_size
            logger.debug(f"DEBUG: Input file size: {input_size / (1024**3):.2f} GB")
            # This is a rough check - actual available space calculation is more complex
            logger.debug(f"DEBUG: Output directory: {output_path.parent}")
        except Exception as e:
            logger.info(f"❌ {method_name} failed with exception: {e}")
            return False
        else:
            logger.info(f"❌ {method_name} failed with return code {return_code}")
            return False
            logger.warning(f"⚠️ Could not check disk space: {e}")

    def _stream_quantisation_output(self, process: subprocess.Popen) -> None:
        """Stream quantisation output in real-time."""
        while True:
            if process.stdout is not None:
                output = process.stdout.readline()
            else:
                break
            if not output and process.poll() is not None:
                break
            if output:
                logger.info(f"📊 {output.strip()}")
        logger.info(f"🎯 Attempting {context.config.name} quantisation...")
        logger.debug(f"DEBUG: Source: {context.f16_model_path}")
        logger.debug(f"DEBUG: Target: {output_path}")
        logger.debug(f"DEBUG: imatrix: {context.imatrix_path}")

        try:
            # Use Python API for quantisation
            logger.info("🐍 Using Python API for quantisation...")
            logger.debug("DEBUG: Calling python_api.quantise_model...")

            success = self.python_api.quantise_model(
                context.f16_model_path, output_path, context.config, context.imatrix_path
            )

            logger.debug(f"DEBUG: Python API returned: {success}")

            if success:
                logger.debug("DEBUG: Quantisation successful, creating success result")
                return self._create_success_result(context.config.name, output_path, "Python API")

            logger.error(f"❌ {context.config.name} quantisation failed")
            return QuantisationResult(
                quantisation_type=QuantisationType(context.config.name),
                success=False,
                error_message="Quantisation failed via Python API",
            )

        except Exception as e:
            logger.error(f"❌ Exception during {context.config.name} quantisation: {e}")
            logger.error("Exception traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")

            return QuantisationResult(
                quantisation_type=QuantisationType(context.config.name),
                success=False,
                error_message=f"Exception during quantisation: {e!s}",
            )

    def _create_success_result(
        self, quant_type: str, output_path: Path, method_used: str
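
From the caller's perspective the engine now reduces to one Python call; an illustrative sketch, with the enum member and file paths assumed rather than taken from this diff:

from pathlib import Path

api = LlamaCppPythonAPI()
ok = api.quantise_model(
    Path("models/Example-7B/Example-7B-f16.gguf"),  # F16 input
    Path("models/Example-7B/Example-7B-Q4_K_M.gguf"),  # quantised output
    QUANTISATION_CONFIGS[QuantisationType.Q4_K_M],  # config with tensor overrides
    None,  # imatrix_path; pass a Path to guide low-bit quantisation
)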

@@ -197,17 +140,15 @@ class ModelManager:
    providing unified interface for model acquisition and preparation.
    """

    def __init__(self, models_dir: Path, environment_manager: EnvironmentManager) -> None:
        """Initialise model manager with storage and environment configuration.
    def __init__(self, models_dir: Path) -> None:
        """Initialise model manager with storage configuration.

        Sets up model storage directory and links to environment manager for
        conversion script access and llama.cpp tool discovery.
        Sets up model storage directory for model downloads and conversions.
        """
        self.models_dir = models_dir
        self.environment_manager = environment_manager
        self.fs = FilesystemService()

    def prepare_model(self, model_source: ModelSource, llama_env: LlamaCppEnvironment) -> Path:
    def prepare_model(self, model_source: ModelSource) -> Path:
        """Prepare model for quantisation and return F16 model path.

        Handles both GGUF repository downloads and regular HuggingFace model

@@ -220,7 +161,7 @@ class ModelManager:
        if model_source.is_gguf_repo:
            return self._handle_gguf_repo(model_source, model_dir)
        return self._handle_regular_repo(model_source, model_dir, llama_env)
        return self._handle_regular_repo(model_source, model_dir)

    def _handle_gguf_repo(self, model_source: ModelSource, model_dir: Path) -> Path:
        """Handle GGUF repository download with pattern matching.

@@ -275,7 +216,6 @@ class ModelManager:
        return self._handle_regular_repo(
            ModelSource(**{**model_source.dict(), "is_gguf_repo": False}),
            model_dir,
            None,
        )

    def _download_gguf_with_patterns(

@@ -308,7 +248,10 @@ class ModelManager:
        temp_dir.mkdir(exist_ok=True)

        try:
            subprocess.run(
            logger.debug(
                f"DEBUG: Running huggingface-cli download for pattern {search_pattern}"
            )
            result = subprocess.run(
                [
                    "timeout",
                    "300",

@@ -322,6 +265,10 @@ class ModelManager:
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.debug(
                f"DEBUG: Download command completed with return code {result.returncode}"
            )

            # Find downloaded GGUF files

@@ -336,9 +283,22 @@ class ModelManager:
            shutil.rmtree(temp_dir)
            return final_path

        except subprocess.CalledProcessError:
        except subprocess.CalledProcessError as e:
            logger.debug(
                f"DEBUG: Pattern {search_pattern} failed with return code {e.returncode}"
            )
            if e.stderr:
                logger.debug(f"DEBUG: stderr: {e.stderr}")
            if e.stdout:
                logger.debug(f"DEBUG: stdout: {e.stdout}")
            logger.info(f"⚠️ Pattern {search_pattern} failed or timed out")
            continue
        except Exception as e:
            logger.error(f"❌ Unexpected error during download: {e}")
            logger.error("Exception traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            continue
        finally:
            if temp_dir.exists():
                shutil.rmtree(temp_dir, ignore_errors=True)

@@ -349,58 +309,123 @@ class ModelManager:
        self,
        model_source: ModelSource,
        model_dir: Path,
        llama_env: LlamaCppEnvironment | None,
    ) -> Path:
        """Handle regular HuggingFace repository conversion.

        Downloads full model repository and converts to F16 GGUF format
        using llama.cpp conversion scripts.
        using our native Python-based GGUFConverter for SafeTensors models.

        Returns:
            Path to converted F16 GGUF model.
        """
        logger.info(f"⬇️ Downloading source model: {model_source.source_model}")

        # Download model if needed
        if not model_dir.exists():
            subprocess.run(
            self._download_repository(model_source.source_model, model_dir)
        else:
            logger.info("✅ Model already downloaded")

        # Convert to GGUF
        return self._convert_to_gguf(model_source, model_dir)

    def _download_repository(self, source_model: str, model_dir: Path) -> None:
        """Download HuggingFace repository.

        Args:
            source_model: HuggingFace model identifier.
            model_dir: Local directory for download.

        Raises:
            RuntimeError: If download fails.
        """
        try:
            logger.debug(f"DEBUG: Downloading full repository: {source_model}")
            result = subprocess.run(
                [
                    "huggingface-cli",
                    "download",
                    model_source.source_model,
                    source_model,
                    "--local-dir",
                    str(model_dir),
                ],
                check=True,
                capture_output=True,
                text=True,
            )
        else:
            logger.info("✅ Model already downloaded")
            logger.debug(
                f"DEBUG: Repository download completed with return code {result.returncode}"
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Failed to download repository {source_model}")
            logger.error(f"Return code: {e.returncode}")
            if e.stderr:
                logger.error(f"stderr: {e.stderr}")
            if e.stdout:
                logger.error(f"stdout: {e.stdout}")
            msg = f"Repository download failed: {e}"
            raise RuntimeError(msg) from e
        except Exception as e:
            logger.error(f"❌ Unexpected error during repository download: {e}")
            logger.error("Exception traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            raise

    def _convert_to_gguf(self, model_source: ModelSource, model_dir: Path) -> Path:
        """Convert model to GGUF F16 format.

        Args:
            model_source: Model source information.
            model_dir: Directory containing model files.

        Returns:
            Path to F16 GGUF model.

        Raises:
            RuntimeError: If conversion fails.
        """
        logger.info("🔄 Converting to GGUF F16 format...")
        f16_model = model_dir / f"{model_source.model_name}-f16.gguf"

        if not f16_model.exists():
            if not llama_env:
                llama_env = self.environment_manager.setup()

            # Ensure conversion script is available
            if llama_env.use_repo or not self.environment_manager.llama_cpp_dir.exists():
                logger.info("Getting conversion script from llama.cpp repository...")
                llama_env = self.environment_manager.setup_repository()

            subprocess.run(
                [
                    *llama_env.convert_script.split(),
                    str(model_dir),
                    "--outtype",
                    "f16",
                    "--outfile",
                    str(f16_model),
                ],
                check=True,
            )
        else:
        if f16_model.exists():
            logger.info("✅ F16 model already exists")
            return f16_model

        # Check for SafeTensors files
        safetensor_files = list(model_dir.glob("*.safetensors"))
        if not safetensor_files:
            logger.error("❌ Model format not supported")
            logger.info("💡 This tool supports GGUF and SafeTensors formats")
            msg = "Model must be in GGUF or SafeTensors format"
            raise RuntimeError(msg)

        logger.info("🐍 Using native Python GGUFConverter...")
        logger.info(f"✅ Found {len(safetensor_files)} SafeTensors files")

        # Load model configuration
        config_parser = ConfigParser()
        model_config = config_parser.load_model_config(model_dir)

        # Get architecture mapping
        arch_name = model_config.architectures[0] if model_config.architectures else "llama"
        arch = config_parser.get_architecture_mapping(arch_name)

        if arch != arch_name:
            logger.info(f"📝 Architecture mapping: {arch_name} → {arch}")

        # Convert using GGUFConverter
        tensor_mapper = TensorMapper()
        success = GGUFConverter.convert_safetensors(
            model_dir, f16_model, model_config, arch, tensor_mapper
        )

        if not success:
            logger.error("❌ Native Python conversion failed")
            msg = "Failed to convert SafeTensors model to GGUF"
            raise RuntimeError(msg)

        logger.info("✅ Native Python conversion successful")
        return f16_model

@@ -437,50 +462,214 @@ class HuggingFaceUploader:
        """Upload or update README file to repository.

        Creates repository if needed, handles existing repository updates.

        Raises:
            RuntimeError: If the README upload fails.
        """
        logger.info("Uploading README...")

        # First ensure the repository exists
        self._ensure_repo_exists(output_repo)

        # Upload without --create flag to avoid PR creation
        try:
            subprocess.run(
            logger.debug(f"DEBUG: Uploading README to {output_repo}")
            result = subprocess.run(
                [
                    "huggingface-cli",
                    "upload",
                    output_repo,
                    str(readme_path),
                    "README.md",
                    "--create",
                    "--commit-message",
                    "Update README.md",
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.info("README uploaded")
        except subprocess.CalledProcessError:
            # Repository exists, update without --create
            logger.debug(f"DEBUG: README upload completed with return code {result.returncode}")
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Failed to upload README to {output_repo}")
            logger.error(f"Return code: {e.returncode}")
            if e.stderr:
                logger.error(f"stderr: {e.stderr}")
            if e.stdout:
                logger.error(f"stdout: {e.stdout}")
            msg = f"README upload failed: {e}"
            raise RuntimeError(msg) from e
        except Exception as e:
            logger.error(f"❌ Unexpected error during README upload: {e}")
            logger.error("Exception traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            raise
        logger.info("README uploaded")

    def _ensure_repo_exists(self, repo_id: str) -> None:
        """Ensure the repository exists, creating it if necessary."""
        try:
            # Try to create the repo - will fail if it already exists
            subprocess.run(
                [
                    "huggingface-cli",
                    "upload",
                    output_repo,
                    str(readme_path),
                    "README.md",
                    "repo",
                    "create",
                    repo_id,
                    "--type",
                    "model",
                    "-y",
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.info("README updated")
            logger.info(f"Created repository: {repo_id}")
        except subprocess.CalledProcessError:
            # Repository already exists, that's fine
            pass
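
The same create-if-missing behaviour is available through the huggingface_hub Python API, which avoids the subprocess round trip; shown for comparison only, with an illustrative repo id, and not used by this commit:

from huggingface_hub import HfApi

HfApi().create_repo("user/model-GGUF", repo_type="model", exist_ok=True)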

    def upload_model_file(self, output_repo: str, model_path: Path) -> None:
        """Upload model file to repository.

        Uploads GGUF model file to specified repository path.
        Always uses huggingface-cli to ensure proper handling of large files
        via HuggingFace's xet backend.

        Raises:
            RuntimeError: If the model file upload fails.
        """
        logger.info(f"Uploading {model_path.name}...")
        subprocess.run(
            [
                "huggingface-cli",
                "upload",
                output_repo,
                str(model_path),
                model_path.name,
            ],
            check=True,
        )

        # Always use huggingface-cli for model files to ensure xet backend is used
        try:
            logger.debug(f"DEBUG: Uploading model file {model_path.name} to {output_repo}")
            result = subprocess.run(
                [
                    "huggingface-cli",
                    "upload",
                    output_repo,
                    str(model_path),
                    model_path.name,
                    "--revision",
                    "main",  # Explicitly push to main branch
                    "--commit-message",
                    f"Add {model_path.name}",
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.debug(f"DEBUG: Model upload completed with return code {result.returncode}")
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Failed to upload model file {model_path.name} to {output_repo}")
            logger.error(f"Return code: {e.returncode}")
            if e.stderr:
                logger.error(f"stderr: {e.stderr}")
            if e.stdout:
                logger.error(f"stdout: {e.stdout}")
            msg = f"Model file upload failed: {e}"
            raise RuntimeError(msg) from e
        except Exception as e:
            logger.error(f"❌ Unexpected error during model file upload: {e}")
            logger.error("Exception traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f"  {line}")
            raise

        # Extract and log the URL if present in output
        if result.stdout:
            for line in result.stdout.splitlines():
                if "https://huggingface.co/" in line:
                    logger.info(f"Upload URL: {line.strip()}")
                    break

        logger.info(f"{model_path.name} uploaded")

    def _try_git_upload_file(
        self,
        repo_id: str,
        local_path: Path,
        repo_path: str,
        *,
        create_repo: bool = False,
    ) -> bool:
        """Try to upload file using git directly to avoid PR creation.

        Returns:
            bool: True if upload successful, False if should fallback to CLI.
        """
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                repo_url = f"https://huggingface.co/{repo_id}"

                # Clone repository
                logger.info(f"Cloning {repo_url}...")
                result = subprocess.run(
                    ["git", "clone", repo_url, str(temp_path / "repo")],
                    check=False,
                    capture_output=True,
                    text=True,
                )

                if result.returncode != 0:
                    if create_repo:
                        # Repository doesn't exist, let huggingface-cli handle creation
                        return False
                    logger.warning(f"Clone failed: {result.stderr}")
                    return False

                repo_dir = temp_path / "repo"
                target_file = repo_dir / repo_path

                # Ensure target directory exists
                target_file.parent.mkdir(parents=True, exist_ok=True)

                # Copy file
                shutil.copy2(local_path, target_file)

                # Check if there are any changes
                status_result = subprocess.run(
                    ["git", "status", "--porcelain"],
                    cwd=repo_dir,
                    capture_output=True,
                    text=True,
                    check=True,
                )

                if not status_result.stdout.strip():
                    logger.info(f"No changes detected for {repo_path}, file already up-to-date")
                    return True  # File is already up-to-date, no need to push

                # Git add, commit, push
                subprocess.run(
                    ["git", "add", repo_path],
                    cwd=repo_dir,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                subprocess.run(
                    ["git", "commit", "-m", f"Update {repo_path}"],
                    cwd=repo_dir,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                subprocess.run(
                    ["git", "push"],
                    cwd=repo_dir,
                    check=True,
                    capture_output=True,
                    text=True,
                )

                return True

        except subprocess.CalledProcessError as e:
            logger.warning(f"Git upload failed: {e}")
            return False
        except Exception as e:
            logger.warning(f"Git upload error: {e}")
            return False
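
Typical use of the git path is as a first attempt, falling back to huggingface-cli (which can also create the repository); the repo id and file are illustrative, and this call pattern is assumed rather than shown in the diff:

from pathlib import Path

uploader = HuggingFaceUploader()
readme = Path("README.md")
if not uploader._try_git_upload_file("user/model-GGUF", readme, "README.md", create_repo=True):
    uploader.upload_readme("user/model-GGUF", readme)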