llm-gguf-tools/helpers/services/llama_python.py

"""Python API wrapper for llama-cpp-python quantisation operations.
Provides high-level Python interfaces for model quantisation using llama-cpp-python
bindings. Implements partial tensor-specific quantisation support through embedding
and output tensor type configuration.
"""

from __future__ import annotations

import ctypes
import gc
import logging
import os
import signal
import sys
import traceback
from typing import TYPE_CHECKING, Any, ClassVar, Never

import psutil

from helpers.logger import logger
from helpers.services.gguf import GGUFConverter
from helpers.utils.config_parser import ConfigParser
from helpers.utils.tensor_mapping import TensorMapper

if TYPE_CHECKING:
    from pathlib import Path

    from helpers.models.quantisation import QuantisationConfig

# Import llama_cpp when needed
try:
    import llama_cpp
    from llama_cpp import llama_model_quantize_params

    LLAMA_CPP_AVAILABLE = True
except ImportError:
    LLAMA_CPP_AVAILABLE = False
    logger.warning("llama-cpp-python not available - falling back to binary mode")


class LlamaCppPythonAPI:
    """Python API wrapper for llama.cpp quantisation operations.

    Provides direct Python access to quantisation functionality using llama-cpp-python
    bindings. Implements partial tensor-specific quantisation through token embedding
    and output tensor type configuration, which provides differentiation between
    Q4_K variants even without full per-layer tensor control.
    """

    # Mapping of custom variant prefixes to their base types
    VARIANT_BASE_MAPPING: ClassVar[dict[str, str]] = {
        "Q3_K_": "Q3_K_M",
        "Q4_K_": "Q4_K_M",
        "Q5_K_": "Q5_K_M",
        "Q6_K_": "Q6_K",
    }

    @staticmethod
    def is_available() -> bool:
        """Check if llama-cpp-python is available for use.

        Returns:
            True if llama-cpp-python bindings are installed and functional.
        """
        return LLAMA_CPP_AVAILABLE

    @staticmethod
    def get_quantisation_type(config_name: str) -> int:
        """Map configuration name to llama_cpp quantisation type constant.

        Supports a wide range of quantisation types from Q2 to Q8, including
        K-quants and legacy formats. Handles both simple formats (Q4_K_M, Q6_K)
        and custom suffixed variants (Q4_K_M_L, Q5_K_M_XL) by mapping them to
        their base types for llama-cpp-python compatibility.

        Returns:
            llama_cpp quantisation type constant for base quantisation.

        Raises:
            RuntimeError: If llama-cpp-python is not available.
            ValueError: If the quantisation type is not supported.
        """
        if not LLAMA_CPP_AVAILABLE:
            msg = "llama-cpp-python not available"
            raise RuntimeError(msg)

        # Normalise the config name to extract base type
        # E.g., "Q4_K_L" or "Q4_K_XL" -> "Q4_K_M" (default for Q4_K)
        # E.g., "Q4_K_M_XXL" -> "Q4_K_M"
        config_upper = config_name.upper()

        # Direct mapping for exact matches
        type_mapping = {
            # Q2 variants (not recommended but supported)
            "Q2_K": llama_cpp.LLAMA_FTYPE_MOSTLY_Q2_K,
            "Q2_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q2_K_S,
            # Q3 K-quants
            "Q3_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q3_K_S,
            "Q3_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q3_K_M,
            # Q4 K-quants (most common)
            "Q4_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_S,
            "Q4_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_K_M,
            # Q5 K-quants
            "Q5_K_S": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_K_S,
            "Q5_K_M": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_K_M,
            # Q6_K (single variant)
            "Q6_K": llama_cpp.LLAMA_FTYPE_MOSTLY_Q6_K,
            # Q8_0 (highest common quantisation)
            "Q8_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q8_0,
            # Legacy quantisation formats
            "Q4_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0,
            "Q4_1": llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_1,
            "Q5_0": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_0,
            "Q5_1": llama_cpp.LLAMA_FTYPE_MOSTLY_Q5_1,
            # IQ (Integer Quantisation) variants - experimental
            "IQ2_XXS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_XXS,
            "IQ2_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_XS,
            "IQ2_S": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_S,
            "IQ2_M": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ2_M,
            "IQ3_XXS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_XXS,
            "IQ3_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_XS,
            "IQ3_S": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_S,
            "IQ3_M": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ3_M,
            "IQ4_NL": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ4_NL,
            "IQ4_XS": llama_cpp.LLAMA_FTYPE_MOSTLY_IQ4_XS,
            # Higher precision formats
            "F16": llama_cpp.LLAMA_FTYPE_MOSTLY_F16,
            "BF16": llama_cpp.LLAMA_FTYPE_MOSTLY_BF16,
        }

        # Try direct lookup first
        if config_upper in type_mapping:
            return type_mapping[config_upper]

        # Handle custom variants using base mapping
        for prefix, base_type in LlamaCppPythonAPI.VARIANT_BASE_MAPPING.items():
            if config_upper.startswith(prefix) and config_upper not in type_mapping:
                return type_mapping[base_type]

        # If not found, raise an informative error
        supported = sorted(type_mapping.keys())
        msg = (
            f"Unsupported quantisation type: {config_name}\n"
            f"Supported types: {', '.join(supported)}\n"
            f"Custom variants like Q4_K_L, Q4_K_XL are also supported."
        )
        raise ValueError(msg)
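
    # Illustrative usage sketch (comments only, not part of the original file);
    # the results assume the mapping table above and llama-cpp-python being
    # installed:
    #
    #     LlamaCppPythonAPI.get_quantisation_type("Q4_K_M")   # exact match
    #     LlamaCppPythonAPI.get_quantisation_type("Q4_K_L")   # falls back to Q4_K_M base
    #     LlamaCppPythonAPI.get_quantisation_type("Q5_K_XL")  # falls back to Q5_K_M base
    #     LlamaCppPythonAPI.get_quantisation_type("Q9_K")     # raises ValueError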

    @staticmethod
    def get_tensor_type_value(type_name: str) -> int:
        """Convert tensor type name to llama_cpp constant.

        Maps string tensor type names to their corresponding llama_cpp integer
        constants for tensor-specific overrides. Provides the foundation for
        differentiated quantisation strategies across embedding and output layers.

        Returns:
            Integer value for the tensor type, or 0 if not found.
        """
        if not LLAMA_CPP_AVAILABLE:
            return 0

        # Build mapping with variant consolidation
        # All Q3_K variants map to base Q3_K type, same for Q4_K and Q5_K
        type_mapping = LlamaCppPythonAPI._build_tensor_type_mapping()
        return type_mapping.get(type_name.upper(), 0)

    @staticmethod
    def _build_tensor_type_mapping() -> dict[str, int]:
        """Build tensor type mapping with variant consolidation.

        Returns:
            Dictionary mapping type names to GGML constants.
        """
        if not LLAMA_CPP_AVAILABLE:
            return {}

        # Base mappings
        return {
            # Q2 variants
            "Q2_K": llama_cpp.GGML_TYPE_Q2_K,
            # Q3 variants - all map to base Q3_K
            "Q3_K": llama_cpp.GGML_TYPE_Q3_K,
            "Q3_K_S": llama_cpp.GGML_TYPE_Q3_K,
            "Q3_K_M": llama_cpp.GGML_TYPE_Q3_K,
            "Q3_K_L": llama_cpp.GGML_TYPE_Q3_K,
            # Q4 variants
            "Q4_0": llama_cpp.GGML_TYPE_Q4_0,
            "Q4_1": llama_cpp.GGML_TYPE_Q4_1,
            "Q4_K": llama_cpp.GGML_TYPE_Q4_K,
            "Q4_K_S": llama_cpp.GGML_TYPE_Q4_K,
            "Q4_K_M": llama_cpp.GGML_TYPE_Q4_K,
            # Q5 variants
            "Q5_0": llama_cpp.GGML_TYPE_Q5_0,
            "Q5_1": llama_cpp.GGML_TYPE_Q5_1,
            "Q5_K": llama_cpp.GGML_TYPE_Q5_K,
            "Q5_K_S": llama_cpp.GGML_TYPE_Q5_K,
            "Q5_K_M": llama_cpp.GGML_TYPE_Q5_K,
            # Q6 variant
            "Q6_K": llama_cpp.GGML_TYPE_Q6_K,
            # Q8 variant
            "Q8_0": llama_cpp.GGML_TYPE_Q8_0,
            # Higher precision
            "F16": llama_cpp.GGML_TYPE_F16,
            "F32": llama_cpp.GGML_TYPE_F32,
        }
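
    # Usage sketch (comments only, not part of the original file): the lookup
    # consolidates K-quant size variants onto their base GGML type, so with
    # llama-cpp-python installed one would expect:
    #
    #     LlamaCppPythonAPI.get_tensor_type_value("Q4_K_M")   # llama_cpp.GGML_TYPE_Q4_K
    #     LlamaCppPythonAPI.get_tensor_type_value("Q8_0")     # llama_cpp.GGML_TYPE_Q8_0
    #     LlamaCppPythonAPI.get_tensor_type_value("UNKNOWN")  # 0 (no override applied)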

    def quantise_model_flexible(
        self,
        input_path: Path,
        output_path: Path,
        base_type: str,
        embedding_type: str | None = None,
        output_type: str | None = None,
        imatrix_path: Path | None = None,
    ) -> bool:
        """Quantise model with flexible tensor type configuration.

        Provides control over base quantisation type with optional overrides for
        embeddings and output layers, which are the only tensor-specific controls
        that work reliably with llama-cpp-python.

        Args:
            input_path: Path to input GGUF model.
            output_path: Path for output quantised model.
            base_type: Base quantisation type (e.g., "Q4_K_M", "Q6_K").
            embedding_type: Override for token embeddings (None = use base).
            output_type: Override for output/lm_head layers (None = use base).
            imatrix_path: Optional importance matrix file.

        Returns:
            True if quantisation successful, False otherwise.

        Examples:
            # Q4_K_L: Q4_K_M base with Q8_0 embeddings
            api.quantise_model_flexible(
                input_path, output_path, "Q4_K_M",
                embedding_type="Q8_0"
            )

            # Q3_K_L: Q3_K_M base with Q5_K output
            api.quantise_model_flexible(
                input_path, output_path, "Q3_K_M",
                output_type="Q5_K"
            )

            # Q3_K_XL: Q3_K_M with both Q8_0 embeddings and Q5_K output
            api.quantise_model_flexible(
                input_path, output_path, "Q3_K_M",
                embedding_type="Q8_0",
                output_type="Q5_K"
            )

        Raises:
            RuntimeError: If llama-cpp-python is not available.
        """
        if not LLAMA_CPP_AVAILABLE:
            msg = "llama-cpp-python not available for quantisation"
            raise RuntimeError(msg)

        logger.info(f"🔄 Flexible quantisation: {base_type} base")
        logger.info(f"📝 Input: {input_path}")
        logger.info(f"📝 Output: {output_path}")

        # Setup phase - create and configure parameters
        params = self._create_params(base_type, imatrix_path)
        self._apply_tensor_overrides(params, embedding_type, output_type)

        # Execution phase - perform quantisation
        try:
            logger.debug("DEBUG: Starting flexible quantisation execution")
            result = self._do_quantisation(input_path, output_path, params)
            logger.debug(f"DEBUG: Flexible quantisation returned: {result}")
        except Exception as e:
            logger.error(f"❌ Flexible quantisation failed with exception: {e}")
            logger.error("Flexible quantisation traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")
            return False
        else:
            if result == 0:
                # Verify output file was created and is valid
                if not output_path.exists():
                    logger.error(
                        f"❌ Quantisation claimed success but output does not exist: {output_path}"
                    )
                    return False

                try:
                    output_size = output_path.stat().st_size
                    logger.debug(f"DEBUG: Output file size: {output_size / (1024**3):.2f} GB")
                    if output_size == 0:
                        logger.error("❌ Output file is empty despite success code")
                        return False
                except Exception as e:
                    logger.warning(f"⚠️ Could not check output file size: {e}")

                logger.info(f"✅ Quantisation successful: {output_path.name}")
                return True

            logger.error(f"❌ Quantisation failed with code: {result}")
            return False

    def _create_params(
        self, base_type: str, imatrix_path: Path | None
    ) -> llama_model_quantize_params:
        """Create quantisation parameters.

        Returns:
            Configured quantisation parameters.
        """
        params = llama_model_quantize_params()
        params.ftype = self.get_quantisation_type(base_type)
        params.nthread = 8
        params.allow_requantize = True

        if imatrix_path and imatrix_path.exists():
            # Convert path to bytes and create c_char_p, then cast to c_void_p
            imatrix_bytes = str(imatrix_path).encode("utf-8")
            char_p = ctypes.c_char_p(imatrix_bytes)
            params.imatrix = ctypes.cast(char_p, ctypes.c_void_p)
            logger.info(f"🧮 Using imatrix: {imatrix_path.name}")

        return params
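
    # Note (added for illustration, not part of the original file): llama.cpp
    # expects the imatrix path as a raw C pointer, hence the c_char_p -> c_void_p
    # cast above. The same pattern in isolation, with a hypothetical path:
    #
    #     buf = ctypes.c_char_p(b"/tmp/imatrix.dat")
    #     ptr = ctypes.cast(buf, ctypes.c_void_p)
    #
    # The bytes object behind the pointer should remain referenced until the
    # quantisation call completes, otherwise the pointer may dangle.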

    def _apply_tensor_overrides(
        self,
        params: llama_model_quantize_params,
        embedding_type: str | None,
        output_type: str | None,
    ) -> None:
        """Apply embedding and output tensor type overrides to params.

        These are the only tensor-specific controls that work reliably
        with llama-cpp-python.
        """
        # Apply embedding override if specified
        if embedding_type:
            params.token_embedding_type = self.get_tensor_type_value(embedding_type)
            logger.info(f"⚙️ Token embedding type: {embedding_type}")

        # Apply output override if specified
        if output_type:
            params.output_tensor_type = self.get_tensor_type_value(output_type)
            params.quantize_output_tensor = True
            logger.info(f"⚙️ Output tensor type: {output_type}")

    def _do_quantisation(
        self,
        input_path: Path,
        output_path: Path,
        params: llama_model_quantize_params,
    ) -> int:
        """Perform the quantisation operation.

        Returns:
            Return code (0 for success).

        Raises:
            KeyboardInterrupt: If the user interrupts the quantisation process.
            SystemExit: If the system exits during quantisation.
        """
        logger.debug("DEBUG: Calling llama_cpp.llama_model_quantize")
        try:
            # Flush any pending output before calling C library
            sys.stdout.flush()
            sys.stderr.flush()

            # Temporarily redirect stderr to prevent terminal control issues
            # Some GGUF models output control sequences that can break the terminal
            old_stderr_fd = None
            devnull_fd = None
            try:
                # Only redirect if not in debug mode to preserve error messages
                if not logger.isEnabledFor(logging.DEBUG):
                    old_stderr_fd = os.dup(2)  # Save current stderr
                    devnull_fd = os.open(os.devnull, os.O_WRONLY)
                    os.dup2(devnull_fd, 2)  # Redirect stderr to /dev/null

                # Call the quantisation with proper exception handling
                result = llama_cpp.llama_model_quantize(
                    str(input_path).encode("utf-8"), str(output_path).encode("utf-8"), params
                )
            finally:
                # Restore stderr if we redirected it
                if old_stderr_fd is not None:
                    os.dup2(old_stderr_fd, 2)
                    os.close(old_stderr_fd)
                if devnull_fd is not None:
                    os.close(devnull_fd)

            # Flush output after the call
            sys.stdout.flush()
            sys.stderr.flush()
        except KeyboardInterrupt:
            logger.error("❌ Quantisation interrupted by user")
            raise
        except SystemExit as e:
            logger.error(f"❌ System exit during quantisation: {e}")
            raise
        except Exception as e:
            logger.error(f"❌ llama_model_quantize call failed: {e}")
            logger.error("llama_model_quantize call traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")
            raise
        else:
            logger.debug(f"DEBUG: llama_model_quantize completed with code: {result}")
            return result

    def quantise_model(
        self,
        input_path: Path,
        output_path: Path,
        config: QuantisationConfig,
        imatrix_path: Path | None = None,
    ) -> bool:
        """Quantise model using Python API.

        Performs quantisation using llama-cpp-python's direct API access with
        support for embedding and output tensor type overrides. The L and XL
        variants use a base type with specific overrides.

        Args:
            input_path: Path to input GGUF model.
            output_path: Path for output quantised model.
            config: Quantisation configuration defining base type and overrides.
            imatrix_path: Optional importance matrix file.

        Returns:
            True if quantisation successful, False otherwise.

        Raises:
            RuntimeError: If llama-cpp-python is not available.
        """
        if not LLAMA_CPP_AVAILABLE:
            msg = "llama-cpp-python not available for quantisation"
            raise RuntimeError(msg)

        # Force cleanup before starting
        gc.collect()

        # Log initial resource state
        mem_before = self._log_resource_state("before")

        try:
            # Validate input
            if not self._validate_input_file(input_path):
                return False

            # Set up parameters
            params = self._setup_quantisation_params(config, imatrix_path)
            if params is None:
                return False

            # Execute quantisation
            result = self._execute_quantisation(input_path, output_path, params)

            # Verify and finalize
            if result == 0:
                return self._finalize_successful_quantisation(output_path, mem_before)

            logger.error(f"❌ Quantisation failed with code: {result}")
        except Exception as e:
            logger.error(f"❌ Quantisation failed with exception: {e}")
            logger.error("Full quantisation traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")

        # Garbage collect and return false
        gc.collect()
        return False
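
    # Usage sketch (comments only, not part of the original file; paths and the
    # `cfg` object are placeholders): a typical call, assuming a QuantisationConfig
    # with base_type="Q4_K_M" and embedding_type="Q8_0" for a Q4_K_L-style recipe:
    #
    #     api = LlamaCppPythonAPI()
    #     api.quantise_model(
    #         Path("model-f16.gguf"),
    #         Path("model-Q4_K_L.gguf"),
    #         cfg,
    #         imatrix_path=Path("imatrix.dat"),
    #     )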

    def _log_resource_state(self, phase: str) -> float:
        """Log current resource usage state.

        Args:
            phase: Description of current phase (e.g., "before", "after").

        Returns:
            Current memory usage in GB.
        """
        process = psutil.Process()
        memory_gb = process.memory_info().rss / (1024**3)
        logger.debug(f"DEBUG: Memory {phase} quantisation: {memory_gb:.2f} GB")
        logger.debug(f"DEBUG: Open file descriptors: {len(process.open_files())}")
        if phase == "before":
            logger.debug(f"DEBUG: Process PID: {process.pid}")
        return memory_gb

    def _validate_input_file(self, input_path: Path) -> bool:
        """Validate input file exists and is readable.

        Args:
            input_path: Path to input file.

        Returns:
            True if file is valid, False otherwise.
        """
        logger.debug(f"DEBUG: Starting quantisation of {input_path.name}")
        logger.info(f"🔄 Quantising {input_path.name}...")
        logger.debug(f"DEBUG: Input: {input_path}")

        if not input_path.exists():
            logger.error(f"❌ Input file does not exist: {input_path}")
            return False
        if not input_path.is_file():
            logger.error(f"❌ Input path is not a file: {input_path}")
            return False

        try:
            input_size = input_path.stat().st_size
            logger.debug(f"DEBUG: Input file size: {input_size / (1024**3):.2f} GB")
            if input_size == 0:
                logger.error("❌ Input file is empty")
                return False
        except Exception as e:
            logger.warning(f"⚠️ Could not check input file size: {e}")

        return True

    def _setup_quantisation_params(
        self,
        config: QuantisationConfig,
        imatrix_path: Path | None,
    ) -> llama_model_quantize_params | None:
        """Set up quantisation parameters.

        Args:
            config: Quantisation configuration.
            imatrix_path: Optional path to importance matrix.

        Returns:
            Configured parameters or None if setup failed.
        """
        logger.debug("DEBUG: Setting up quantisation parameters")
        params = llama_model_quantize_params()

        # Set base quantisation type
        try:
            params.ftype = self.get_quantisation_type(config.base_type)
            logger.debug(
                f"DEBUG: Set ftype to {params.ftype} for {config.base_type} (config: {config.name})"
            )
        except Exception as e:
            logger.error(f"❌ Failed to get quantisation type for {config.name}: {e}")
            return None

        # Configure basic parameters
        params.nthread = 8
        params.allow_requantize = True
        logger.debug(
            f"DEBUG: Set nthread={params.nthread}, allow_requantize={params.allow_requantize}"
        )

        # Add imatrix if available
        if imatrix_path and imatrix_path.exists():
            try:
                # Convert path to bytes and create c_char_p, then cast to c_void_p
                imatrix_bytes = str(imatrix_path).encode("utf-8")
                char_p = ctypes.c_char_p(imatrix_bytes)
                params.imatrix = ctypes.cast(char_p, ctypes.c_void_p)
                logger.info(f"🧮 Using imatrix: {imatrix_path.name}")
                logger.debug(f"DEBUG: imatrix path set: {imatrix_path}")
            except Exception as e:
                logger.error(f"❌ Failed to set imatrix: {e}")
                # Continue without imatrix

        # Configure tensor-specific types
        logger.debug("DEBUG: Configuring tensor-specific types")
        try:
            self._configure_tensor_types(params, config)
            logger.debug("DEBUG: Tensor types configured successfully")
        except Exception as e:
            logger.error(f"❌ Failed to configure tensor types: {e}")
            logger.error("Tensor type configuration traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")
            # Continue with default types

        return params

    def _execute_quantisation(
        self,
        input_path: Path,
        output_path: Path,
        params: llama_model_quantize_params,
    ) -> int:
        """Execute the actual quantisation with signal handling.

        Args:
            input_path: Path to input model.
            output_path: Path for output model.
            params: Configured quantisation parameters.

        Returns:
            Return code from quantisation (0 for success).
        """
        logger.debug("DEBUG: Starting llama_cpp.llama_model_quantize call")
        logger.debug("DEBUG: About to call llama_model_quantize...")

        # Set up signal handlers
        old_handlers = self._setup_signal_handlers()

        try:
            result = llama_cpp.llama_model_quantize(
                str(input_path).encode("utf-8"), str(output_path).encode("utf-8"), params
            )
            logger.debug(f"DEBUG: llama_model_quantize returned: {result}")
        except Exception as e:
            logger.error(f"❌ llama_model_quantize raised exception: {e}")
            logger.error("llama_model_quantize traceback:")
            for line in traceback.format_exc().splitlines():
                logger.error(f" {line}")
            return -1
        else:
            return result
        finally:
            self._restore_signal_handlers(old_handlers)

    def _setup_signal_handlers(self) -> tuple[Any, Any | None]:
        """Set up signal handlers for debugging termination.

        Returns:
            Tuple of (old_sigterm, old_sigsegv) handlers.
        """

        def signal_debug_handler(signum: int, frame: object) -> Never:  # noqa: ARG001
            logger.error(f"DEBUG: Received signal {signum} during quantisation!")
            logger.error(f"DEBUG: Signal name: {signal.Signals(signum).name}")
            msg = f"Signal {signum} received"
            raise KeyboardInterrupt(msg)

        old_sigterm = signal.signal(signal.SIGTERM, signal_debug_handler)
        old_sigsegv = (
            signal.signal(signal.SIGSEGV, signal_debug_handler)
            if hasattr(signal, "SIGSEGV")
            else None
        )
        return old_sigterm, old_sigsegv

    def _restore_signal_handlers(self, handlers: tuple[Any, Any | None]) -> None:
        """Restore original signal handlers.

        Args:
            handlers: Tuple of (old_sigterm, old_sigsegv) handlers.
        """
        old_sigterm, old_sigsegv = handlers
        signal.signal(signal.SIGTERM, old_sigterm)
        if old_sigsegv is not None:
            signal.signal(signal.SIGSEGV, old_sigsegv)

    def _finalize_successful_quantisation(
        self,
        output_path: Path,
        mem_before: float,
    ) -> bool:
        """Finalize successful quantisation and verify output.

        Args:
            output_path: Path to output file.
            mem_before: Memory usage before quantisation in GB.

        Returns:
            True if output is valid, False otherwise.
        """
        logger.debug("DEBUG: Quantisation returned success code")

        # Verify output exists
        if not output_path.exists():
            logger.error(
                f"❌ Quantisation claimed success but output does not exist: {output_path}"
            )
            return False

        # Verify output size
        output_size = output_path.stat().st_size
        logger.debug(f"DEBUG: Output file size: {output_size / (1024**3):.2f} GB")
        if output_size == 0:
            logger.error("❌ Output file is empty despite success code")
            return False

        logger.info(f"✅ Quantisation successful: {output_path.name}")

        # Force cleanup and log final state
        gc.collect()
        mem_after = self._log_resource_state("after")
        logger.debug(f"DEBUG: Memory delta: {mem_after - mem_before:+.2f} GB")
        return True

    def _configure_tensor_types(
        self, params: llama_model_quantize_params, config: QuantisationConfig
    ) -> None:
        """Configure tensor-specific quantisation types.

        Sets embedding and output tensor type overrides based on config.
        These are the only tensor-specific controls that work reliably
        with llama-cpp-python.
        """
        logger.debug(f"DEBUG: _configure_tensor_types called for {config.name}")

        # Apply embedding override if specified
        if config.embedding_type:
            params.token_embedding_type = self.get_tensor_type_value(config.embedding_type)
            logger.info(f"⚙️ Token embedding type: {config.embedding_type}")

        # Apply output override if specified
        if config.output_type:
            params.output_tensor_type = self.get_tensor_type_value(config.output_type)
            params.quantize_output_tensor = True
            logger.info(f"⚙️ Output tensor type: {config.output_type}")

    def convert_hf_to_gguf(
        self, input_dir: Path, output_path: Path, output_type: str = "f16"
    ) -> bool:
        """Convert HuggingFace model to GGUF format using native Python converter.

        Uses our GGUFConverter for SafeTensors models, providing full Python-based
        conversion without external dependencies.

        Args:
            input_dir: Directory containing the HuggingFace model.
            output_path: Path for the output GGUF file.
            output_type: Output precision (e.g., "f16").

        Returns:
            True if conversion successful, False otherwise.
        """
        logger.info(f"🔄 Converting {input_dir.name} to GGUF format...")
        logger.info(f"📝 Input: {input_dir}")
        logger.info(f"📝 Output: {output_path}")
        logger.info(f"📝 Type: {output_type}")

        # Check for SafeTensors files
        safetensor_files = list(input_dir.glob("*.safetensors"))
        if not safetensor_files:
            logger.warning("⚠️ No SafeTensors files found in model directory")
            return False

        try:
            # Load model configuration
            config_parser = ConfigParser()
            model_config = config_parser.load_model_config(input_dir)

            # Get architecture mapping
            arch_name = model_config.architectures[0] if model_config.architectures else "llama"
            arch = config_parser.get_architecture_mapping(arch_name)
            if arch != arch_name:
                logger.info(f"📝 Architecture mapping: {arch_name}{arch}")

            # Convert using GGUFConverter
            tensor_mapper = TensorMapper()
            success = GGUFConverter.convert_safetensors(
                input_dir, output_path, model_config, arch, tensor_mapper
            )
        except Exception as e:
            logger.error(f"❌ Conversion failed with exception: {e}")
            return False
        else:
            if success:
                logger.info("✅ Native Python conversion successful")
            return success
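

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the original module; the paths
# below are placeholders). It exercises the public API under the assumption
# that llama-cpp-python is installed: convert a SafeTensors checkpoint to an
# F16 GGUF, then apply a Q4_K_L-style recipe (Q4_K_M base, Q8_0 embeddings).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pathlib import Path

    api = LlamaCppPythonAPI()
    if not api.is_available():
        raise SystemExit("llama-cpp-python is not installed")

    model_dir = Path("models/example-hf-model")      # hypothetical HF model directory
    f16_path = Path("models/example-f16.gguf")       # intermediate F16 GGUF
    quant_path = Path("models/example-Q4_K_L.gguf")  # final quantised output

    # Convert first, then quantise only if conversion reports success.
    if api.convert_hf_to_gguf(model_dir, f16_path, output_type="f16"):
        api.quantise_model_flexible(
            f16_path, quant_path, "Q4_K_M", embedding_type="Q8_0"
        )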