Adding LoRA support (#249)

* adding training forward

* tests

* tests

* fix

* plop

* plop

* plop

* lora all but not multilinear

* all lora works+lazy comp

* minor

* minor fix

* lora ckpt loaders

* checkpointing

* checkpointing + new lora loading

* tokenizer path

* quantization back (to test)

* small fixes

* fuse_lora arg

* remove compile

* edits from alex

* edits from alex

* trying dynamic no compile

* reformat configs + loading

* wip

* wip

* fix typing

* fixing loading of old models

* some improvements

* remove param

* plop

* Lora and quant (#248)

* wip

* wip

* fix typing

* fixing loading of old models

* some improvements

* remove param

* plop

* fix

* fix

* fine-tune loading

* Import lora models for mlx.

* meta init

* exchange order

* Rust lora importer.

* update load weight

* fixing meta init

* load weight in moshi

* no post hook

* fix

* dbg

---------

Co-authored-by: hippolytepilchen <hippolyte.pilchen@gmail.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
Alexandre Défossez 2025-04-01 12:06:39 +02:00 committed by GitHub
parent 5707114ca4
commit ea5401cc3a
22 changed files with 1170 additions and 284 deletions

View file

@ -31,3 +31,6 @@ runs:
run: |
sudo apt-get update
sudo apt-get install libasound2-dev
echo "test"
cmake --version
apt-cache show cmake

View file

@ -13,6 +13,12 @@ repos:
entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pyright'
pass_filenames: false
always_run: true
- id: tests-moshi
name: pytests on moshi package
language: system
entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pytest tests'
pass_filenames: false
always_run: true
- id: flake8-moshi_mlx
name: flake8 on moshi_mlx package
language: system

View file

@ -3,3 +3,4 @@ include *.md
include *.cfg
include requirements.txt
include moshi/py.typed
include tests/assets/*.safetensors

View file

@ -16,4 +16,4 @@ from . import modules
from . import quantization
from . import utils
__version__ = "0.2.3"
__version__ = "0.2.4a1"

View file

@ -16,7 +16,6 @@ import torch
from torch import nn
from ..modules.transformer import create_sin_embedding
from ..utils import quantize
logger = logging.getLogger(__name__)
@ -212,7 +211,7 @@ class BaseConditioner(nn.Module, tp.Generic[Prepared]):
cond = torch.zeros(B, T, C, device=cond.device, dtype=cond.dtype)
mask = torch.zeros_like(cond[..., 0], dtype=torch.bool)
cond = quantize.linear(self.output_proj, cond)
cond = self.output_proj(cond)
maskf = mask.float()[..., None]
if self.learnt_padding is not None:

View file

@ -13,59 +13,31 @@ from dataclasses import dataclass, field
from functools import partial
import logging
import typing as tp
import torch
from torch import nn
from ..conditioners import ConditionProvider, ConditionFuser, ConditionTensors
from ..utils.sampling import sample_token
from ..utils.compile import CUDAGraphed
from ..utils import quantize
from ..utils.quantize import replace_linear_with_qlinear
from ..modules.streaming import StreamingContainer, StreamingModule, State
from ..modules.transformer import (
StreamingTransformer,
quantize_transformer,
create_norm_fn,
)
from ..modules.transformer import StreamingTransformer, create_norm_fn
from .lm_utils import (_delay_sequence,
_undelay_sequence,
_init_layer,
ScaledEmbedding)
logger = logging.getLogger(__name__)
class ScaledEmbedding(nn.Embedding):
"""Boost learning rate for embeddings (with `scale`).
Args:
norm (bool): if True, uses a layer norm after the embedding.
zero_idx (int): special value indicating that the output should be exactly 0.
low_rank (int | None): if provided, uses low rank embedding with a linear layer to reach
the desired dimension. Quite efficient for reducing the number of weights for very large vocabs.
"""
def __init__(self, num_embeddings: int, embedding_dim: int,
*args, norm: bool = False, zero_idx: int = -1,
low_rank: int | None = None, **kwargs):
super().__init__(num_embeddings, low_rank or embedding_dim, *args, **kwargs)
self.norm = None
if norm:
self.norm = create_norm_fn("layer_norm", self.embedding_dim)
assert zero_idx < 0, "Please use negative values for the zero_idx."
self.zero_idx = zero_idx
self.low_rank = None
if low_rank is not None:
self.low_rank = nn.Linear(low_rank, embedding_dim, bias=False)
def forward(self, input, *args, **kwargs):
is_zero = input == self.zero_idx
zero = torch.zeros(1, dtype=input.dtype, device=input.device)
input = input.clamp(min=0)
y = super().forward(input, *args, **kwargs)
if self.norm is not None:
y = self.norm(y)
y = torch.where(is_zero[..., None], zero, y)
if self.low_rank is not None:
y = quantize.linear(self.low_rank, y)
return y
@dataclass
class LMOutput:
# The logits are already re-aligned with the input codes
# hence no extra shift is required, e.g. when computing CE
logits: torch.Tensor # [B, K, T, card]
mask: torch.Tensor # [B, K, T]
text_logits: torch.Tensor # [B, 1, T, text_card]
text_mask: torch.Tensor # [B, 1, T]
class LMModel(StreamingContainer):
@ -88,8 +60,7 @@ class LMModel(StreamingContainer):
depformer_dim_feedforward (int| list[int]| None): If None, defaults to hidden_scale * depformer_dim.
depformer_weights_per_step_schedule (list[int] | None): mapping `CODEBOOK_INDEX -> WEIGHT_INDEX`, allowing
depformer_low_rank_embeddings (int | None): if provided, uses low rank embeddings, with a linear
existing_text_padding_id (bool): if True, will use a different token for the initial text token, and
the text padding token.
existing_text_padding_id (int): token to use for the padding.
same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
**kwargs: Additional parameters for the transformer encoder.
"""
@ -114,13 +85,16 @@ class LMModel(StreamingContainer):
depformer_weights_per_step_schedule: list[int] | None = None,
depformer_low_rank_embeddings: int | None = None,
depformer_pos_emb: str = "sin",
existing_text_padding_id: tp.Optional[int] = None,
existing_text_padding_id: int = 3,
existing_text_end_padding_id: int = 0,
context: tp.Optional[int] = None,
causal: bool = True,
condition_provider: tp.Optional[ConditionProvider] = None,
fuser: tp.Optional[ConditionFuser] = None,
quantize: bool = False,
device=None,
dtype=None,
gradient_checkpointing: bool = False,
**kwargs,
):
super().__init__()
@ -132,11 +106,11 @@ class LMModel(StreamingContainer):
self.delays = delays
self.dim = dim
self.existing_text_padding_id = existing_text_padding_id
self.existing_text_end_padding_id = existing_text_end_padding_id
self.context = context
self.depformer_weights_per_step_schedule = depformer_weights_per_step_schedule
if depformer_weights_per_step_schedule is not None:
assert len(depformer_weights_per_step_schedule) == dep_q
kwargs["context"] = context
EmbeddingFactory = partial(
ScaledEmbedding,
norm=norm_emb,
@ -147,11 +121,10 @@ class LMModel(StreamingContainer):
self.emb = nn.ModuleList(
[EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
)
# Text card + padding token (if not in the original tokenizer)
extra_text = self.existing_text_padding_id is None
# Unlike for audio, here we authorize the model to output the special token.
self.text_emb = EmbeddingFactory(text_card + 1, dim)
self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
self.text_linear = nn.Linear(dim, text_card, bias=bias_proj)
depformer_prefix = "depformer_"
main_kwargs = {
k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
@ -164,6 +137,9 @@ class LMModel(StreamingContainer):
device=device,
dtype=dtype,
quantize=quantize,
context=context,
causal=causal,
checkpointing=gradient_checkpointing,
**main_kwargs,
)
self.out_norm = create_norm_fn(norm, dim)
@ -205,7 +181,9 @@ class LMModel(StreamingContainer):
dim_feedforward=depformer_dim_feedforward,
norm=norm,
weights_per_step_schedule=depformer_weights_per_step_schedule,
causal=causal,
quantize=quantize,
checkpointing=gradient_checkpointing,
device=device,
dtype=dtype,
**kwargs_dep,
@ -221,8 +199,9 @@ class LMModel(StreamingContainer):
self.condition_provider = condition_provider
self.fuser = fuser
self.to(device=device, dtype=dtype)
self._init_weights()
if quantize:
quantize_transformer(self)
replace_linear_with_qlinear(self)
@property
def initial_token_id(self) -> int:
@ -237,15 +216,12 @@ class LMModel(StreamingContainer):
@property
def text_padding_token_id(self) -> int:
"""Token id for text padding."""
if self.existing_text_padding_id is None:
return self.text_card
else:
return self.existing_text_padding_id
return self.existing_text_padding_id
@property
def end_of_text_padding_id(self) -> int:
"""Token id for optionally marking the last padding step for a word."""
return 0
return self.existing_text_end_padding_id
@property
def zero_token_id(self) -> int:
@ -294,6 +270,61 @@ class LMModel(StreamingContainer):
token = torch.cat([text_token, audio_token], dim=1)
return token
def forward(
self, codes: torch.Tensor,
condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
"""Given an input tensor of codes [B, K, T] and list of conditions, returns the logits
along with masks indicating the valid positions at which to compute the loss.
The logits time steps are aligned with those in the input `code`.
Should only be used for training, not inference (use `LMGen` for that).
Args:
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
K the number of codebooks and T the number of timesteps. When text is supported,
the first 'codebook' corresponds to the text, and the remaining codebooks are for the audio.
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning tensors.
Returns:
LMOutput: Language model outputs, containing either text or audio logits, or both.
logits (torch.Tensor, or None) of shape [B, K, T, card] corresponding to the provided codes,
i.e. the first item corresponds to logits to predict the first code, meaning that
no additional shifting of codes and logits is required.
mask (torch.Tensor, or None) of shape [B, K, T], mask over valid and invalid positions.
Given the specified interleaving strategies, parts of the logits and codes should
not be considered as valid predictions because of invalid context.
text_logits (torch.Tensor, or None) of shape [B, 1, T, text_card].
text_mask (torch.Tensor, or None) of shape [B, 1, T], mask over the valid positions for the text.
"""
B, K, T = codes.shape
assert K == self.num_codebooks, (K, self.num_codebooks)
# Delaying codes and removing the last time step that will never be an input.
initial = self._get_initial_token().expand(B, -1, -1)
delayed_codes = _delay_sequence(self.delays, codes, initial)
# Inserting the empty tokens for the first time step.
delayed_codes = torch.cat([initial, delayed_codes], dim=2)
sum_condition: torch.Tensor | None = None
if condition_tensors is None:
assert self.fuser is None
else:
assert self.fuser is not None
sum_condition = self.fuser.get_sum(condition_tensors)
transformer_out, text_logits = self.forward_text(delayed_codes[:, :, :-1], sum_condition)
assert transformer_out.shape[0] == delayed_codes.shape[0]
assert transformer_out.shape[1] == delayed_codes.shape[2] - 1
logits = self.forward_depformer_training(delayed_codes[:, :, 1:], transformer_out)
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
# and provide the corresponding mask over invalid positions of tokens. We will fill invalid positions with NaN
# to ensure they are properly handled.
logits, logits_mask = _undelay_sequence(
self.delays[self.audio_offset:self.audio_offset + self.dep_q],
logits, fill_value=float('NaN'))
logits_mask &= (codes[:, self.audio_offset: self.audio_offset + self.dep_q] != self.zero_token_id)
text_logits, text_logits_mask = _undelay_sequence(self.delays[:1], text_logits, fill_value=float('NaN'))
text_logits_mask &= (codes[:, :1] != self.zero_token_id)
return LMOutput(logits, logits_mask, text_logits, text_logits_mask)
def forward_text(
self,
sequence: torch.Tensor, sum_condition: torch.Tensor | None = None
@ -310,18 +341,54 @@ class LMModel(StreamingContainer):
)
input_ = audio_emb if input_ is None else input_ + audio_emb
text_emb = self.text_emb(input_sequence[:, 0])
input_ = text_emb if input_ is None else input_ + text_emb
if sum_condition is not None:
input_ = input_ + sum_condition.to(input_)
transformer_out = self.transformer(input_)
if self.out_norm:
transformer_out = self.out_norm(transformer_out)
assert isinstance(transformer_out, torch.Tensor)
text_logits = quantize.linear(self.text_linear, transformer_out)
text_logits = self.text_linear(transformer_out)
text_logits = text_logits[:, None]
return transformer_out, text_logits
def forward_depformer_training(
self,
sequence: torch.Tensor,
transformer_out: torch.Tensor,
) -> torch.Tensor:
B, K, T = sequence.shape
Ka = self.dep_q
assert (
K == self.num_codebooks
), f"Codebooks for Depformer training should be passed all at once, got {K,}."
depformer_inputs = []
for cb_index in range(Ka):
if self.depformer_multi_linear:
linear_index = cb_index
if self.depformer_weights_per_step_schedule is not None:
linear_index = self.depformer_weights_per_step_schedule[cb_index]
transformer_in = self.depformer_in[linear_index](transformer_out)
else:
transformer_in = self.depformer_in[0](transformer_out)
if cb_index == 0:
token_in = self.depformer_text_emb(sequence[:, 0])
else:
token_in = self.depformer_emb[cb_index - 1](sequence[:, cb_index + self.audio_offset - 1])
depformer_inputs.append(token_in + transformer_in)
depformer_input = torch.stack(depformer_inputs, 2)
# depformer_input is [B, T, K, depformer_dim], reshaping to [B * T, K, D]
depformer_input = depformer_input.view(B * T, Ka, -1)
depformer_output = self.depformer(depformer_input)
all_logits = []
for cb_index in range(Ka):
logits = self.linears[cb_index](depformer_output[:, cb_index])
all_logits.append(logits.view(B, T, -1))
logits = torch.stack(all_logits, 1)
assert logits.dim() == 4, logits.shape # [B, Ka, T, card]
return logits
def forward_depformer(
self,
depformer_cb_index: int,
@ -344,9 +411,9 @@ class LMModel(StreamingContainer):
in_index = depformer_cb_index
if self.depformer_weights_per_step_schedule is not None:
in_index = self.depformer_weights_per_step_schedule[in_index]
depformer_input = quantize.linear(self.depformer_in[in_index], depformer_input)
depformer_input = self.depformer_in[in_index](depformer_input)
else:
depformer_input = quantize.linear(self.depformer_in[0], depformer_input)
depformer_input = self.depformer_in[0](depformer_input)
if depformer_cb_index == 0:
last_token_input = self.depformer_text_emb(sequence[:, 0])
else:
@ -359,11 +426,35 @@ class LMModel(StreamingContainer):
# depformer_input is [B, 1, depformer_dim].
# The streaming state of the depformer ensures that the proper layer is run.
dep_output = self.depformer(depformer_input)
logits = quantize.linear(self.linears[depformer_cb_index], dep_output)
logits = self.linears[depformer_cb_index](dep_output)
logits = logits[:, None]
assert logits.dim() == 4, logits.shape # [B, Ka, S, card]
return logits
def _init_weights(self):
"""Initialization of the transformer module weights.
Mostly truncated gaussian, with `std = 1 / sqrt(dim_in)`.
Embeddings are also initialized with `1 / sqrt(dim)` rather than `1`.
Some layers are not going to be properly initialized:
- in_proj in MHA.
- depth transformer layers.
This is to match how our models were trained so far.
"""
for emb_layer in self.emb:
_init_layer(emb_layer)
for emb_layer in self.depformer_emb:
_init_layer(emb_layer)
_init_layer(self.text_emb)
_init_layer(self.depformer_text_emb)
_init_layer(self.text_linear)
for tr_layer in self.transformer.layers:
tr_layer.apply(_init_layer)
for linear in self.linears:
_init_layer(linear)
@dataclass
class _LMGenState(State):
@ -482,7 +573,7 @@ class LMGen(StreamingModule[_LMGenState]):
k = lm_model.dep_q + 1 + q_other
delay = lm_model.delays[k]
write_position = (state.offset + delay) % CT
state.cache[:, k, write_position : write_position + 1] = input_tokens[
state.cache[:, k, write_position: write_position + 1] = input_tokens[
:, q_other
]
@ -492,7 +583,7 @@ class LMGen(StreamingModule[_LMGenState]):
# token that are delayed, and thus have no good value to take.
if state.offset <= delay:
state.cache[:, k, position] = state.initial[:, k, 0]
input_ = state.cache[:, :, position : position + 1]
input_ = state.cache[:, :, position: position + 1]
if self.check:
# Check that we are not feeding in any value that is not generated yet.
@ -526,7 +617,7 @@ class LMGen(StreamingModule[_LMGenState]):
state.offset += 1
position = state.offset % CT
state.cache[:, 0, position] = text_token
state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
state.cache[:, 1:lm_model.dep_q + 1, position] = audio_tokens
if state.offset <= self.max_delay:
return None

View file

@ -0,0 +1,107 @@
import math
import typing as tp
import torch
from torch import nn
from ..modules.transformer import create_norm_fn
def _delay_sequence(delays: tp.List[int], tensor: torch.Tensor, padding: torch.Tensor) -> torch.Tensor:
B, K, T = tensor.shape
assert len(delays) == K, (len(delays), K)
outs = []
for k, delay in enumerate(delays):
assert delay >= 0
line = tensor[:, k].roll(delay, dims=1)
if delay > 0:
line[:, :delay] = padding[:, k]
outs.append(line)
return torch.stack(outs, dim=1)
def _undelay_sequence(delays: tp.List[int], tensor: torch.Tensor,
fill_value: tp.Union[int, float] = float('NaN')) -> tp.Tuple[torch.Tensor, torch.Tensor]:
B, K, T, *_ = tensor.shape
assert len(delays) == K
mask = torch.ones(B, K, T, dtype=torch.bool, device=tensor.device)
outs = []
if all([delay == 0 for delay in delays]):
return tensor, mask
for k, delay in enumerate(delays):
assert delay >= 0
line = tensor[:, k].roll(-delay, dims=1)
if delay > 0:
line[:, -delay:] = fill_value
mask[:, k, -delay:] = 0
outs.append(line)
return torch.stack(outs, dim=1), mask
def _get_init_fn(input_dim: int) -> tp.Callable[[torch.Tensor], None]:
def _init(x: torch.Tensor) -> None:
std = 1 / math.sqrt(input_dim)
x_orig = x
if x.device.type == 'cpu' and x.dtype in [torch.float16, torch.bfloat16]:
x = x.float()
torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
if x_orig is not x:
x_orig.data[:] = x.to(x_orig)
return _init
def _init_layer(m: nn.Module,
zero_bias_init: bool = True):
if isinstance(m, nn.Linear):
init_fn = _get_init_fn(m.in_features)
init_fn(m.weight)
if zero_bias_init and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
init_fn = _get_init_fn(m.embedding_dim)
init_fn(m.weight)
class ScaledEmbedding(nn.Embedding):
"""Boost learning rate for embeddings (with `scale`).
Args:
norm (bool): if True, uses a layer norm after the embedding.
zero_idx (int): special value indicating that the output should be exactly 0.
low_rank (int | None): if provided, uses low rank embedding with a linear layer to reach
the desired dimension. Quite efficient for reducing the number of weights for very large vocabs.
lr (float or None): learning rate to use, only valid if the `make_optim_group()` method is used.
"""
def __init__(self, num_embeddings: int, embedding_dim: int,
*args, norm: bool = False, zero_idx: int = -1,
low_rank: int | None = None, lr: float | None = None, **kwargs):
super().__init__(num_embeddings, low_rank or embedding_dim, *args, **kwargs)
self.norm = None
if norm:
self.norm = create_norm_fn("layer_norm", self.embedding_dim)
assert zero_idx < 0, "Please use negative values for the zero_idx."
self.zero_idx = zero_idx
self.lr = lr
self.low_rank = None
if low_rank is not None:
self.low_rank = nn.Linear(low_rank, embedding_dim, bias=False)
def forward(self, input, *args, **kwargs):
is_zero = input == self.zero_idx
zero = torch.zeros(1, dtype=input.dtype, device=input.device)
input = input.clamp(min=0)
y = super().forward(input, *args, **kwargs)
if self.norm is not None:
y = self.norm(y)
y = torch.where(is_zero[..., None], zero, y)
if self.low_rank is not None:
y = self.low_rank(y)
return y
def make_optim_group(self) -> dict:
group: dict[str, tp.Any] = {"params": list(self.parameters())}
if self.lr is not None:
group["lr"] = self.lr
return group
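
Note (not part of the diff): to make the delay pattern above concrete, here is a small standalone sketch that mirrors the roll-and-pad logic of `_delay_sequence`. It uses a scalar pad value where the real code uses the per-codebook initial token, so it is an illustration rather than the package API.

import torch

def delay_codes(codes: torch.Tensor, delays: list[int], pad_value: int) -> torch.Tensor:
    # codes: [B, K, T]; codebook k is shifted right by delays[k], pad_value fills the gap.
    outs = []
    for k, delay in enumerate(delays):
        line = codes[:, k].roll(delay, dims=1)
        if delay > 0:
            line[:, :delay] = pad_value
        outs.append(line)
    return torch.stack(outs, dim=1)

codes = torch.arange(12).view(1, 3, 4)            # B=1, K=3, T=4
print(delay_codes(codes, [0, 1, 2], pad_value=-1))
# codebook 0 is unchanged, codebook 1 lags by one step, codebook 2 by two.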

View file

@ -2,36 +2,37 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Retrieves the pretrained models for Moshi and Mimi."""
from dataclasses import dataclass, field
import json
from pathlib import Path
import warnings
from huggingface_hub import hf_hub_download
try:
from huggingface_hub.errors import EntryNotFoundError
except ImportError:
from huggingface_hub.utils import EntryNotFoundError # pyright: ignore
from safetensors.torch import load_model
from safetensors.torch import load_model, load_file
import sentencepiece
import torch
import typing as tp
from .compression import MimiModel
from ..conditioners import BaseConditioner, ConditionProvider, ConditionFuser
from .lm import LMModel
from ..modules import SEANetEncoder, SEANetDecoder, transformer
from ..quantization import SplitResidualVectorQuantizer
from ..modules.lora import replace_all_linear_with_lora, replace_lora_with_linear
SAMPLE_RATE = 24000
FRAME_RATE = 12.5
TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
MOSHI_NAME = 'model.safetensors'
MOSHI_Q8_NAME = 'model.q8.safetensors'
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
TEXT_TOKENIZER_NAME = "tokenizer_spm_32k_3.model"
MOSHI_NAME = "model.safetensors"
MOSHI_Q8_NAME = "model.q8.safetensors"
MIMI_NAME = "tokenizer-e351c8d8-checkpoint125.safetensors"
DEFAULT_REPO = "kyutai/moshiko-pytorch-bf16"
_seanet_kwargs = {
@ -99,7 +100,6 @@ _lm_kwargs = {
"depformer_dim_feedforward": int(4.125 * 1024),
"depformer_num_heads": 16,
"depformer_num_layers": 6,
"depformer_causal": True,
"depformer_layer_scale": None,
"depformer_multi_linear": True,
"depformer_context": 8,
@ -139,20 +139,25 @@ class CheckpointInfo:
raw_config: raw config, including original keys not intended for the LM.
model_type: indicate the intended use, should be `moshi` or `hibiki`.
"""
moshi_weights: Path
mimi_weights: Path
tokenizer: Path
lm_config: dict | None = None
raw_config: dict | None = None
model_type: str = 'moshi'
model_type: str = "moshi"
lora_weights: Path | None = None
lm_gen_config: dict = field(default_factory=dict)
@staticmethod
def from_hf_repo(hf_repo: str,
moshi_weights: Path | str | None = None,
mimi_weights: Path | str | None = None,
tokenizer: Path | str | None = None,
config_path: Path | str | None = None) -> 'CheckpointInfo':
def from_hf_repo(
hf_repo: str,
moshi_weights: Path | str | None = None,
mimi_weights: Path | str | None = None,
tokenizer: Path | str | None = None,
config_path: Path | str | None = None,
lora_weights: Path | str | None = None,
) -> "CheckpointInfo":
"""Downloads the checkpoints from the given repo, along with its config.
Extra overrides are possible for each of Moshi, Mimi, or the text tokenizer,
@ -163,28 +168,32 @@ class CheckpointInfo:
"""
if config_path is None:
try:
config_path = hf_hub_download(hf_repo, 'config.json')
config_path = hf_hub_download(hf_repo, "config.json")
except EntryNotFoundError:
# No config.json, which might indicate legacy repository.
warnings.warn(f"Repository {hf_repo} contains no config.json. "
"Assuming this is a Moshi 7B. Support for such repository "
"might be removed in the future.")
warnings.warn(
f"Repository {hf_repo} contains no config.json. "
"Assuming this is a Moshi 7B. Support for such repository "
"might be removed in the future."
)
if config_path is None:
moshi_name = MOSHI_NAME
mimi_name = MIMI_NAME
tokenizer_name = TEXT_TOKENIZER_NAME
lm_config = None
raw_config = None
model_type = 'moshi'
model_type = "moshi"
lm_gen_config = {}
lora_name = None
else:
raw_config = json.loads(Path(config_path).read_text())
lm_config = dict(raw_config)
moshi_name = lm_config.pop('moshi_name', MOSHI_NAME)
mimi_name = lm_config.pop('mimi_name', MIMI_NAME)
tokenizer_name = lm_config.pop('tokenizer_name', TEXT_TOKENIZER_NAME)
model_type = lm_config.pop('model_type', 'moshi')
lm_gen_config = lm_config.pop('lm_gen_config', {})
moshi_name = lm_config.pop("moshi_name", MOSHI_NAME)
mimi_name = lm_config.pop("mimi_name", MIMI_NAME)
tokenizer_name = lm_config.pop("tokenizer_name", TEXT_TOKENIZER_NAME)
lora_name = lm_config.pop("lora_name", None)
model_type = lm_config.pop("model_type", "moshi")
lm_gen_config = lm_config.pop("lm_gen_config", {})
if moshi_weights is None:
moshi_weights_final = hf_get(moshi_name, hf_repo)
@ -201,23 +210,47 @@ class CheckpointInfo:
else:
tokenizer_final = hf_get(tokenizer)
return CheckpointInfo(
moshi_weights_final, mimi_weights_final, tokenizer_final,
lm_config, raw_config, model_type, lm_gen_config)
if lora_weights is None and lora_name:
lora_weights_final = hf_get(lora_name, hf_repo)
elif lora_weights is not None:
lora_weights_final = hf_get(lora_weights)
else:
lora_weights_final = None
def get_mimi(self, device: torch.device | str = 'cpu') -> MimiModel:
return CheckpointInfo(
moshi_weights_final,
mimi_weights_final,
tokenizer_final,
lm_config,
raw_config,
model_type,
lora_weights_final,
lm_gen_config=lm_gen_config,
)
def get_mimi(self, device: torch.device | str = "cpu") -> MimiModel:
if self.lm_config is None:
num_codebooks = 8
else:
num_codebooks = self.lm_config['dep_q']
num_codebooks = self.lm_config["dep_q"]
return get_mimi(self.mimi_weights, num_codebooks=num_codebooks, device=device)
def get_moshi(self, strict: bool = True, device: torch.device | str = 'cpu',
dtype: torch.dtype = torch.bfloat16) -> LMModel:
def get_moshi(
self,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.bfloat16,
load_weight: bool = True,
**kwargs,
) -> LMModel:
model = get_moshi_lm(
self.moshi_weights, lm_kwargs=self.lm_config,
device=device, dtype=dtype, strict=strict)
if self.model_type == 'hibiki':
self.moshi_weights if load_weight else None,
lm_kwargs=self.lm_config,
device=device,
dtype=dtype,
lora_weights=self.lora_weights,
**kwargs,
)
if self.model_type == "hibiki":
# Sometimes the model samples the EOS (2) too early, which we want to ignore.
# We keep generating if the input file is not finished, and this is a way
# to implicitly replace early EOS with PAD.
@ -232,9 +265,9 @@ def _is_safetensors(path: Path | str) -> bool:
return Path(path).suffix in (".safetensors", ".sft", ".sfts")
def get_mimi(filename: str | Path,
device: torch.device | str = 'cpu',
num_codebooks: int = 8) -> MimiModel:
def get_mimi(
filename: str | Path, device: torch.device | str = "cpu", num_codebooks: int = 8
) -> MimiModel:
"""Return a pretrained Mimi model."""
encoder = SEANetEncoder(**_seanet_kwargs)
decoder = SEANetDecoder(**_seanet_kwargs)
@ -262,7 +295,7 @@ def get_mimi(filename: str | Path,
).to(device=device)
model.eval()
if _is_safetensors(filename):
load_model(model, filename)
load_model(model, filename, device=str(device))
else:
pkg = torch.load(filename, "cpu")
model.load_state_dict(pkg["model"])
@ -270,46 +303,99 @@ def get_mimi(filename: str | Path,
return model
def get_moshi_lm(filename: str | Path,
lm_kwargs: tp.Optional[tp.Dict] = None,
device: torch.device | str = 'cpu',
dtype: torch.dtype = torch.bfloat16,
strict: bool = True) -> LMModel:
def get_moshi_lm(
filename: str | Path | None,
lm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.bfloat16,
lora_weights: str | Path | None = None,
fuse_lora: bool = False,
lm_kwargs_overrides={},
) -> LMModel:
if lm_kwargs is None:
lm_kwargs = _lm_kwargs
lm_kwargs = dict(lm_kwargs)
assert lm_kwargs is not None
if "conditioners" in lm_kwargs:
lm_kwargs["condition_provider"] = get_conditioner_provider(lm_kwargs["dim"], device, lm_kwargs)
lm_kwargs["condition_provider"] = get_conditioner_provider(
lm_kwargs["dim"], device, lm_kwargs
)
del lm_kwargs["conditioners"]
if "fuser" in lm_kwargs:
lm_kwargs["fuser"] = get_condition_fuser(lm_kwargs)
lm_kwargs = lm_kwargs | lm_kwargs_overrides
assert lm_kwargs is not None
# deprecated params.
lm_kwargs.pop("depformer_causal", None)
# lora params.
lora = lm_kwargs.pop("lora", False)
lora_rank = lm_kwargs.pop("lora_rank", 128)
lora_scaling = lm_kwargs.pop("lora_scaling", 2.0)
init_device = device
if filename is not None:
init_device = torch.device('meta')
model = LMModel(
device=device,
device=init_device,
dtype=dtype,
**lm_kwargs)
model.eval()
if _is_safetensors(filename):
load_model(model, filename, strict=strict)
else:
pkg = torch.load(
filename,
"cpu",
if filename is not None:
if _is_safetensors(filename):
state = load_file(filename, device=str(device))
for key, value in state.items():
if value.dtype.is_floating_point:
value = value.to(dtype=dtype)
state[key] = value
model.load_state_dict(state, assign=True)
else:
pkg = torch.load(filename, "cpu",)
model.load_state_dict(pkg["fsdp_best_state"]["model"], assign=True)
if lora:
assert not lm_kwargs.get("quantize"), (
"LoRA and quantization are incompatible for now."
)
model.load_state_dict(pkg["fsdp_best_state"]["model"])
model = get_lora_moshi(
model=model,
lora_rank=lora_rank,
lora_scaling=lora_scaling,
lora_weights=lora_weights,
device=device,
dtype=dtype,
fuse_lora=fuse_lora,
)
else:
assert lora_weights is None, (
"`lora` is False, but received some lora_weights to load."
)
model.eval()
return model
def get_conditioner(output_dim: int, device: torch.device | str, conditioner_cfg: dict) -> BaseConditioner:
def get_conditioner(
output_dim: int, device: torch.device | str, conditioner_cfg: dict
) -> BaseConditioner:
conditioner_type = conditioner_cfg["type"]
conditioner_kwargs = conditioner_cfg[conditioner_type]
conditioner_kwargs.update({'output_dim': output_dim, 'device': device})
if conditioner_type == 'lut':
conditioner_kwargs.update({"output_dim": output_dim, "device": device})
if conditioner_type == "lut":
from ..conditioners.text import LUTConditioner
return LUTConditioner(**conditioner_kwargs)
else:
raise RuntimeError(f"Unknow conditioner type {conditioner_type}.")
def get_conditioner_provider(output_dim: int, device: torch.device | str, cfg: dict) -> ConditionProvider:
def get_conditioner_provider(
output_dim: int, device: torch.device | str, cfg: dict
) -> ConditionProvider:
"""Instantiate a conditioning model."""
conditioners: tp.Dict[str, BaseConditioner] = {}
for cond, cond_cfg in cfg["conditioners"].items():
@ -321,8 +407,39 @@ def get_conditioner_provider(output_dim: int, device: torch.device | str, cfg: d
def get_condition_fuser(cfg: dict) -> ConditionFuser:
"""Instantiate a condition fuser object."""
fuser_cfg = cfg["fuser"]
fuser_methods = ['sum', 'cross', 'prepend']
fuser_methods = ["sum", "cross", "prepend"]
fuse2cond = {k: fuser_cfg.get(k, []) for k in fuser_methods}
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
return fuser
def get_lora_moshi(
model: LMModel,
lora_weights: str | Path | None,
lora_rank: int,
lora_scaling: float,
dtype: torch.dtype = torch.bfloat16,
device: torch.device | str = "cpu",
fuse_lora: bool = True,
) -> LMModel:
init_device = device
if lora_weights is not None:
init_device = torch.device('meta')
replace_all_linear_with_lora(model, lora_rank, lora_scaling, device=init_device)
if lora_weights is not None:
assert _is_safetensors(lora_weights), "LoRA weights must be a safetensors file."
lora_state_dict = load_file(lora_weights, device=str(device))
for key, value in lora_state_dict.items():
if value.dtype.is_floating_point:
value = value.to(dtype=dtype)
lora_state_dict[key] = value
res = model.load_state_dict(lora_state_dict, strict=False, assign=True)
if res.unexpected_keys:
raise RuntimeError(
f"unexpected_keys in the lora weights: {res.unexpected_keys}"
)
model = model.to(dtype=dtype, device=device)
if fuse_lora:
replace_lora_with_linear(model)
return model
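
Note (not part of the diff): a minimal loading sketch using the new entry points, for orientation. The repo is the default one and the extra arguments are optional overrides; whether a LoRA adapter is fetched depends on the repo's config.json.

from moshi.models import loaders
import torch

# If the repo's config.json defines `lora_name`, the adapter is fetched automatically;
# a local safetensors file can also be passed via `lora_weights=`.
info = loaders.CheckpointInfo.from_hf_repo(loaders.DEFAULT_REPO)
mimi = info.get_mimi(device="cpu")
# fuse_lora=True folds W' = W + scaling * B @ A back into plain nn.Linear for inference.
lm = info.get_moshi(device="cpu", dtype=torch.bfloat16, fuse_lora=True)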

View file

@ -2,12 +2,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import ExitStack
import torch
from torch import nn
from torch.nn import functional as F
from ..utils.compile import torch_compile_lazy
from ..utils import quantize
from ..utils.compile import torch_compile_lazy, no_compile
@torch_compile_lazy
@ -22,6 +22,20 @@ def gating_forward_kernel(
return x
def gating_forward_generic(
linear_in: nn.Module,
linear_out: nn.Module,
activation,
x: torch.Tensor
):
x = linear_in(x)
B, T, _ = x.shape
x = x.view(B, T, 2, -1)
x = activation(x[..., 0, :]) * x[..., 1, :]
x = linear_out(x)
return x
class ActivationGating(nn.Module):
"""
Gating FFN layer, using the given activation.
@ -42,23 +56,30 @@ class ActivationGating(nn.Module):
hidden = (21 * dim) // 8
else:
hidden = (2 * dim_feedforward) // 3
self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
# We try to follow the default PyTorch MHA convention, to easily compare results.
self.activation = activation
def forward(self, x: torch.Tensor):
if quantize.is_quantized(self.linear_in):
assert quantize.is_quantized(self.linear_out)
x = quantize.linear(self.linear_in, x)
B, T, _ = x.shape
x = x.view(B, T, 2, -1)
x = self.activation(x[..., 0, :]) * x[..., 1, :]
x = quantize.linear(self.linear_out, x)
return x
return gating_forward_kernel(
self.linear_in.weight, self.linear_out.weight, self.activation, x
)
if isinstance(self.linear_in, nn.Linear):
assert isinstance(self.linear_out, nn.Linear)
with ExitStack() as stack:
if self.training:
stack.enter_context(no_compile())
return gating_forward_kernel(
self.linear_in.weight, self.linear_out.weight, self.activation, x
)
else:
return gating_forward_generic(
self.linear_in,
self.linear_out,
self.activation,
x
)
def _get_activation(name: str):
@ -73,7 +94,8 @@ def _get_activation(name: str):
def _make_gating(
name: str, dim: int, dim_feedforward: int, **factory_kwargs
name: str, dim: int, dim_feedforward: int,
**factory_kwargs
) -> nn.Module:
return ActivationGating(
dim, dim_feedforward, _get_activation(name), **factory_kwargs
@ -84,9 +106,10 @@ def make_gating(
name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
max_params = 2 * dim * dim_feedforward
params = sum(p.numel() for p in gating.parameters())
assert (
params <= max_params
), f"{name} gating has {params} params, max is {max_params}"
if isinstance(gating.linear_in, nn.Linear):
max_params = 2 * dim * dim_feedforward
params = sum(p.numel() for p in gating.parameters())
assert (
params <= max_params
), f"{name} gating has {params} params, max is {max_params}"
return gating
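
Note (not part of the diff): `gating_forward_generic` is what ActivationGating falls back to when the projections are not plain nn.Linear (e.g. LoRALinear or QLinear). A quick sketch with plain linears, assuming the module is importable as moshi.modules.gating:

import torch
from torch import nn
from torch.nn import functional as F
from moshi.modules.gating import gating_forward_generic

dim, hidden = 16, 24
linear_in = nn.Linear(dim, 2 * hidden, bias=False)   # both gating branches in one projection
linear_out = nn.Linear(hidden, dim, bias=False)
x = torch.randn(2, 5, dim)                            # [B, T, C]
y = gating_forward_generic(linear_in, linear_out, F.silu, x)
print(y.shape)                                        # torch.Size([2, 5, 16])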

moshi/moshi/modules/lora.py (new file, 122 lines)
View file

@ -0,0 +1,122 @@
import torch
import torch.nn as nn
def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None):
""" Recursively replace all Linear layers with LoRALinear layers."""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if device is None:
this_device = child.weight.device
else:
this_device = device
if dtype is None:
this_dtype = child.weight.dtype
else:
this_dtype = dtype
lora = LoRALinear(child.in_features, child.out_features,
rank, scaling, device=this_device, dtype=this_dtype)
lora.frozen_W = child
setattr(module, name, lora)
else:
replace_all_linear_with_lora(child, rank, scaling, device=device, dtype=dtype)
def replace_lora_with_linear(module):
"""Recursively replace all LoRALinear layers with Linear layers."""
for name, child in module.named_children():
if isinstance(child, LoRALinear):
# Compute merged weights: W' = W + scaling * B @ A
merged_weight = child.frozen_W.weight.data + \
child.scaling * (child.lora_B.weight @ child.lora_A.weight)
# Create a standard Linear layer with the same in/out features
new_linear = nn.Linear(child.frozen_W.in_features,
child.frozen_W.out_features, bias=False,
device=torch.device('meta'),
dtype=merged_weight.dtype)
new_linear.weight = nn.Parameter(
merged_weight, requires_grad=merged_weight.requires_grad) # Transfer merged weights
setattr(module, name, new_linear) # Replace the module
else:
replace_lora_with_linear(child) # Recursively process submodules
class LoRALinear(nn.Module):
"""
Implementation of:
- LoRA: https://arxiv.org/abs/2106.09685
Notes:
- Freezing is handled at the network level, not the layer level.
- Scaling factor controls relative importance of LoRA skip
connection versus original frozen weight. General guidance is
to keep it to 2.0 and sweep over learning rate when changing
the rank.
"""
def __init__(
self,
in_features: int,
out_features: int,
rank: int,
scaling: float,
bias: bool = False,
device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
assert not bias
self.bias = bias
self.rank = rank
self.scaling = scaling
self.lora_A = nn.Linear(
self.in_features,
self.rank,
bias=self.bias,
device=device,
dtype=dtype,
)
self.lora_B = nn.Linear(
self.rank,
self.out_features,
bias=self.bias,
device=device,
dtype=dtype,
)
self.frozen_W = nn.Linear(self.in_features,
self.out_features,
bias=self.bias,
device=device,
dtype=dtype)
self._register_load_state_dict_pre_hook(LoRALinear._load_hook, with_module=True)
def merge_weight(self):
with torch.no_grad():
down_weight = self.lora_A.weight
up_weight = self.lora_B.weight
weight = up_weight.mm(down_weight) * self.scaling
weight += self.frozen_W.weight
return weight
@staticmethod
def _load_hook(module, state_dict, prefix, *_):
key_name = prefix + "weight"
if key_name in state_dict:
w_ref = state_dict.pop(key_name)
state_dict[prefix + 'frozen_W.weight'] = w_ref
def forward(self, x: torch.Tensor):
lora = self.lora_B(self.lora_A(x))
return self.frozen_W(x) + lora * self.scaling
def __repr__(self) -> str:
return "{}Linear(in_features={}, out_features={}, r={})".format(
"LoRA", self.in_features, self.out_features, self.rank)

View file

@ -12,25 +12,18 @@ See `StreamingTransformer` for more information.
from contextlib import ExitStack
from dataclasses import dataclass
import typing as tp
from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from ..utils.compile import no_compile
from ..utils import quantize
from ..utils.quantize import replace_linear_with_qlinear
from .gating import make_gating
from .rope import RotaryEmbedding
from .streaming import StreamingModule, StreamingContainer, State
def quantize_transformer(module: torch.nn.Module):
for name, child in module.named_modules():
if isinstance(child, torch.nn.Linear):
quantize.quantize_linear(child)
elif isinstance(child, StreamingMultiheadAttention):
quantize.quantize_param(child, 'in_proj_weight')
from .lora import LoRALinear
from torch.utils.checkpoint import checkpoint as torch_checkpoint
class LayerNormF32(nn.LayerNorm):
@ -260,6 +253,34 @@ class RingKVCache:
return KVCacheResult(keys, values, positions)
def apply_weights_per_step(modules: nn.ModuleList, schedule: list[int] | None,
x: torch.Tensor, offset: int) -> torch.Tensor:
"""Utility to apply a multi linear layer to the given input. A multi linear layer
applies a different set of weights for each time step.
Args:
modules (nn.ModuleList): apply weights per step.
schedule (list[int] or None): schedule for weight sharing.
x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
offset (int): offset for the current time step, in particular for decoding, with
time steps provided one by one.
"""
if len(modules) == 1:
return modules[0](x)
ys: list[torch.Tensor] = []
B, T, C = x.shape
for t in range(T):
module_index = t + offset
if schedule is not None:
module_index = schedule[module_index]
y = modules[module_index](x[:, t: t + 1])
ys.append(y)
out = torch.cat(ys, 1)
return out
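
Note (not part of the diff): a small illustration of the dispatch above, assuming it is importable from moshi.modules.transformer. Three time steps share two weight sets through the schedule, and during streaming decode `offset` picks the right weights for a single step:

import torch
from torch import nn
from moshi.modules.transformer import apply_weights_per_step

dim = 8
modules = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(2)])
schedule = [0, 1, 1]                       # CODEBOOK_INDEX -> WEIGHT_INDEX
x = torch.randn(2, 3, dim)                 # [B, T, C] with T == 3 steps
y = apply_weights_per_step(modules, schedule, x, offset=0)
print(y.shape)                             # torch.Size([2, 3, 8])
# Streaming decode: one step at a time, offset selects the weights for step 2.
y2 = apply_weights_per_step(modules, schedule, x[:, 2:3], offset=2)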
@dataclass
class _MHAState(State):
kv_cache: RingKVCache
@ -316,7 +337,6 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
self.weights_per_step = weights_per_step
self.weights_per_step_schedule = weights_per_step_schedule
out_dim = embed_dim
out_dim = 3 * embed_dim
mult = 1
if weights_per_step:
@ -325,13 +345,49 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
mult = max(weights_per_step_schedule) + 1
else:
mult = weights_per_step
in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
# We try to follow the default PyTorch MHA convention, to easily compare results.
self.in_proj_weight = in_proj.weight
self.in_proj_bias = in_proj.bias
self.out_proj = nn.Linear(
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
self.mult = mult
# Split in one linear per step
self.out_projs = nn.ModuleList(
[
nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
for _ in range(mult)
]
)
self.in_projs = nn.ModuleList(
[
nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs)
for _ in range(mult)
]
)
self._register_load_state_dict_pre_hook(StreamingMultiheadAttention._load_hook, with_module=True)
@staticmethod
def _load_hook(module, state_dict, prefix, *_):
mappings = {
'in_proj_weight': 'in_projs.{i}.weight',
'in_proj.weight': 'in_projs.{i}.weight',
'in_proj.lora_A.weight': 'in_projs.{i}.lora_A.weight',
'in_proj.lora_B.weight': 'in_projs.{i}.lora_B.weight',
'out_proj.weight': 'out_projs.{i}.weight',
'out_proj.lora_A.weight': 'out_projs.{i}.lora_A.weight',
'out_proj.lora_B.weight': 'out_projs.{i}.lora_B.weight',
}
mult = module.mult
# _scb suffix is for quantized data.
for suffix in ['', '_scb']:
for source, target in mappings.items():
this_source = prefix + source + suffix
if this_source in state_dict:
weight = state_dict[this_source]
_, *OD = weight.shape
weight = weight.view(mult, -1, *OD)
for i in range(mult):
this_target = prefix + target.format(i=i) + suffix
state_dict[this_target] = weight[i]
state_dict.pop(this_source)
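
Note (not part of the diff): the reshape at the core of this hook, shown standalone. A fused weight of shape [mult * out_dim, D] is viewed as [mult, out_dim, D] and each slice becomes one per-step projection (shapes here are illustrative):

import torch

mult, out_dim, d_model = 3, 8, 4
fused = torch.randn(mult * out_dim, d_model)      # legacy 'in_proj_weight' layout
per_step = fused.view(mult, -1, d_model)          # same view() as in _load_hook
for i in range(mult):
    assert torch.equal(per_step[i], fused[i * out_dim:(i + 1) * out_dim])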
def _init_streaming_state(self, batch_size: int) -> _MHAState:
if self.context is None:
@ -343,13 +399,20 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
)
else:
capacity = self.context
device = self.in_proj_weight.device
# TODO: the following estimation will not work great with FSDP.
if quantize.is_quantized(self, 'in_proj_weight'):
# We are running with quantization
in_proj = self.in_projs[0]
if isinstance(in_proj, LoRALinear):
device = in_proj.lora_A.weight.device
dtype = in_proj.lora_A.weight.dtype
elif isinstance(in_proj, nn.Linear):
device = in_proj.weight.device
dtype = in_proj.weight.dtype
elif isinstance(in_proj, quantize.QLinear):
device = in_proj.weight.device
dtype = torch.float16
else:
dtype = self.in_proj_weight.dtype
raise RuntimeError(f"Unknown type {type(in_proj)} for linear.")
dim_per_head = self.embed_dim // self.num_heads
kv_cache = RingKVCache(
batch_size, self.num_heads, dim_per_head, capacity, device, dtype
@ -379,16 +442,12 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
offset = state.offset
offset_cpu = state.offset_cpu
if self.weights_per_step:
projected = quantize.multi_linear(
self.weights_per_step, self.weights_per_step_schedule,
self, query, offset_cpu, name='in_proj_weight')
else:
projected = quantize.linear(self, query, 'in_proj_weight')
projected = apply_weights_per_step(
self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
q, k, v = rearrange(
projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
)
if self.rope:
q, k = self.rope(q, k, offset, time_before_heads=False)
@ -407,12 +466,9 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
x = rearrange(x, "b h t d -> b t (h d)")
if self.weights_per_step:
x = quantize.multi_linear(
self.weights_per_step, self.weights_per_step_schedule,
self.out_proj, x, offset_cpu)
else:
x = quantize.linear(self.out_proj, x)
x = apply_weights_per_step(
self.out_projs, self.weights_per_step_schedule, x, offset_cpu)
if state is not None:
state.offset.add_(T)
state.offset_cpu += T
@ -565,19 +621,11 @@ class StreamingTransformerLayer(StreamingModule[_LayerState]):
if self.gating is None:
assert self.linear1 is not None
assert self.linear2 is not None
update = quantize.linear(self.linear2, self.activation(quantize.linear(self.linear1, x)))
update = self.linear2(self.activation(self.linear1(x)))
else:
if self.weights_per_step:
assert isinstance(self.gating, nn.ModuleList)
B, T, D = x.shape
ys = []
for t in range(T):
linear_index = offset + t
if self.weights_per_step_schedule:
linear_index = self.weights_per_step_schedule[linear_index]
y = self.gating[linear_index](x[:, t:t + 1])
ys.append(y)
update = torch.cat(ys, dim=1)
update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, x, offset)
else:
update = self.gating(x)
return x_orig.to(update) + self.layer_scale_2(update)
@ -645,6 +693,7 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
betas: tp.Optional[tp.Tuple[float, float]] = None,
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
quantize: bool = False,
checkpointing: bool = False,
device=None,
dtype=None,
**kwargs,
@ -662,6 +711,8 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
if self.positional_embedding in {"rope", "sin_rope"}:
self.rope = RotaryEmbedding(max_period=max_period)
self.checkpointing = checkpointing
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
@ -680,7 +731,7 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
if quantize:
# Quantizing layers one by one to avoid taking too much space during init.
self.layers[-1].to(device=device, dtype=dtype)
quantize_transformer(self.layers[-1])
replace_linear_with_qlinear(self.layers[-1])
def _init_streaming_state(self, batch_size: int) -> _TransformerState:
device = next(self.parameters()).device
@ -705,7 +756,16 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
x = x + self.positional_scale * pos_emb
for layer in self.layers:
x = layer(x, *args, **kwargs)
if self.checkpointing:
y = torch_checkpoint(
layer, x, *args, use_reentrant=False,
determinism_check='none',
preserve_rng_state=False,
**kwargs)
assert isinstance(y, torch.Tensor)
x = y
else:
x = layer(x, *args, **kwargs)
if state is not None:
state.offset.add_(T)

View file

@ -13,7 +13,6 @@ import tarfile
import time
import secrets
import sys
import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
@ -21,8 +20,6 @@ import numpy as np
import sentencepiece
import sphn
import torch
from .client_utils import log
from .models import loaders, MimiModel, LMModel, LMGen
from .run_inference import get_condition_tensors
@ -71,6 +68,7 @@ class ServerState:
if tokens is None:
continue
_ = self.mimi.decode(tokens[:, 1:])
torch.cuda.synchronize()
async def handle_chat(self, request):
@ -190,8 +188,12 @@ def main():
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
help="HF repo to look into, defaults Moshiko. "
"Use this to select a different pre-trained model.")
parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
parser.add_argument("--cfg-coef", type=float, default=1., help="CFG coefficient.")
parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
parser.add_argument("--no_fuse_lora", action="store_false", dest="fuse_lora", default=True,
help="Do not fuse LoRA layers intot Linear layers.")
parser.add_argument("--half", action="store_const", const=torch.float16, default=torch.bfloat16,
dest="dtype", help="Run inference with float16, not bfloat16, better for old GPUs.")
parser.add_argument(
@ -223,7 +225,8 @@ def main():
log("info", "retrieving checkpoint")
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer)
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
lora_weights=args.lora_weight, config_path=args.config_path)
log("info", "loading mimi")
mimi = checkpoint_info.get_mimi(device=args.device)
log("info", "mimi loaded")
@ -231,7 +234,7 @@ def main():
text_tokenizer = checkpoint_info.get_text_tokenizer()
log("info", "loading moshi")
lm = checkpoint_info.get_moshi(device=args.device, dtype=args.dtype)
lm = checkpoint_info.get_moshi(device=args.device, dtype=args.dtype, fuse_lora=args.fuse_lora)
log("info", "moshi loaded")
state = ServerState(checkpoint_info.model_type, mimi, text_tokenizer, lm, args.cfg_coef, args.device,

View file

@ -15,98 +15,48 @@ import torch
from torch import nn
def linear(module: nn.Module, x: torch.Tensor, name='weight') -> torch.Tensor:
import bitsandbytes as bnb # type: ignore
if is_quantized(module, name):
class QLinear(nn.Module):
def __init__(self, linear: nn.Linear):
super().__init__()
from bitsandbytes import functional as bnbF # type: ignore
weight = linear.weight
assert weight.data.dtype.is_floating_point
assert linear.bias is None
CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16)) # type: ignore
self.weight = nn.Parameter(CB, requires_grad=False)
self.weight_scb = nn.Parameter(SCB, requires_grad=False)
def forward(self, x):
import bitsandbytes as bnb # type: ignore
state = bnb.MatmulLtState()
state.CB = getattr(module, name)
state.CB = self.weight # type: ignore
assert isinstance(state.CB, torch.Tensor)
state.SCB = getattr(module, name + '_scb')
state.SCB = self.weight_scb # type: ignore
assert isinstance(state.SCB, torch.Tensor)
if state.SCB.dtype != torch.float:
raise RuntimeError(
"Expected `weight_scb` to have type float, but got bfloat16. "
"When using quantized models, care should be taken not to change the dtype of "
"the model once initialized.")
assert state.SCB.dtype == torch.float, state.SCB.dtype
state.has_fp16_weights = False
y = bnb.matmul(x.half(), state.CB, state=state)
assert isinstance(y, torch.Tensor)
return y
else:
return nn.functional.linear(x, getattr(module, name))
def multi_linear(num_steps: int, schedule: list[int] | None,
module: nn.Module, x: torch.Tensor, offset: int, name='weight') -> torch.Tensor:
"""Utility to apply a multi linear layer to the given input. A multi linear layer
applies a different set of weight for each time step.
Args:
num_steps (int): Number of possible time steps.
schedule (list[int] or None): schedule for weight sharing.
weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
offset (int): offset for the current time step, in particular for decoding, with
time steps provided one by one.
"""
import bitsandbytes as bnb # type: ignore
B, T, C = x.shape
ys: list[torch.Tensor] = []
if is_quantized(module, name):
weight = getattr(module, name)
weight_scb = getattr(module, name + '_scb')
else:
weight = getattr(module, name)
weight_scb = None
assert isinstance(weight, torch.Tensor)
num_linear = num_steps
if schedule is not None:
num_linear = max(schedule) + 1
chout, chin = weight.shape
weight = weight.view(num_linear, -1, chin)
if weight_scb is not None:
assert isinstance(weight, torch.Tensor)
assert weight_scb.shape == (chout,), (weight_scb, chout)
weight_scb = weight_scb.view(num_linear, -1)
assert weight_scb.dtype == torch.float, weight_scb.dtype
for t in range(T):
linear_index = t + offset
if schedule is not None:
linear_index = schedule[linear_index]
if weight_scb is None:
y = nn.functional.linear(x[:, t], weight[linear_index])
def replace_linear_with_qlinear(module):
"""Recursively replace all Linear layers with QLinear layers."""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, QLinear(child))
elif isinstance(child, QLinear):
# Slight issue with the way we implement things: the scale param
# might get cast with the rest of the model to bfloat16, although
# we most likely want to keep it as float. For the LM model we might call this function twice,
# first layer by layer to avoid too big a memory usage, and second, at the end
# of the LM init, after all other modules are initialized and properly dtyped.
# In any case that should happen before loading the state dict to avoid a loss of precision.
child.float()
else:
state = bnb.MatmulLtState()
CB = weight[linear_index]
state.CB = CB # type: ignore
state.SCB = weight_scb[linear_index]
state.has_fp16_weights = False
y = bnb.matmul(x[:, t].half(), CB, state=state)
assert isinstance(y, torch.Tensor)
ys.append(y)
out = torch.stack(ys, 1)
return out
def is_quantized(module: nn.Module, name: str = 'weight'):
return hasattr(module, name + '_scb')
def quantize_param(module: nn.Module, name: str = 'weight') -> None:
from bitsandbytes import functional as bnbF # type: ignore
if is_quantized(module, name):
# Due to model casting, the type of SCB might be wrong, althought
# that would only happen during the init. Let's recast it to float.
SCB = getattr(module, name + '_scb')
if SCB.dtype != torch.float:
setattr(module, name + '_scb', nn.Parameter(SCB.to(torch.float), requires_grad=False))
return
weight = getattr(module, name)
assert weight.data.dtype.is_floating_point
CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16)) # type: ignore
setattr(module, name, nn.Parameter(CB, requires_grad=False))
setattr(module, name + '_scb', nn.Parameter(SCB, requires_grad=False))
def quantize_linear(linear: nn.Module) -> None:
assert linear.bias is None
quantize_param(linear)
replace_linear_with_qlinear(child)
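
Note (not part of the diff): a small sketch of the new quantization entry point. It assumes bitsandbytes is installed and a CUDA device is available; the layer sizes are arbitrary.

import torch
from torch import nn
from moshi.utils.quantize import replace_linear_with_qlinear

block = nn.Sequential(nn.Linear(1024, 1024, bias=False)).to("cuda", torch.bfloat16)
replace_linear_with_qlinear(block)   # nn.Linear -> QLinear (int8 weights + float scales)
x = torch.randn(1, 8, 1024, device="cuda", dtype=torch.bfloat16)
y = block(x)                         # runs the bitsandbytes int8 matmul, returns float16
print(y.dtype, y.shape)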

View file

@ -0,0 +1,52 @@
import torch
from .compile import torch_compile_lazy
@torch_compile_lazy
def cross_entropy(
logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, dtype=torch.float32,
logits_soft_clip: float | None = None) -> torch.Tensor:
"""Compute cross entropy between multi-codebook targets and model's logits.
The cross entropy is computed per codebook to provide codebook-level cross entropy.
Valid timesteps for each of the codebook are pulled from the mask, where invalid
timesteps are set to 0.
Args:
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
targets (torch.Tensor): Target codes, of shape [B, K, T].
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
dtype (type): Data type of the output cross entropy.
logits_soft_clip (float): Clipping value for the logits to avoid numerical instability.
Recommended value: 30.0.
Returns:
ce (torch.Tensor): Cross entropy [B, K, T] with type dtype.
"""
output_shape = targets.shape
assert logits.shape[:-1] == targets.shape
assert mask.shape == targets.shape
logits = logits.view(-1, logits.shape[-1])
targets = targets.reshape(-1)
mask = mask.reshape(-1)
safe_targets = torch.where(
mask,
targets,
torch.zeros(1, device=targets.device, dtype=targets.dtype),
)
# Chunking the conversion to float32 to avoid OOMs.
ce_chunks = []
for logits_chunk, targets_chunk in zip(torch.chunk(logits, 4), torch.chunk(safe_targets, 4)):
logits_chunk = logits_chunk.to(dtype)
if logits_soft_clip is not None:
logits_chunk = logits_soft_clip * torch.tanh(logits_chunk / logits_soft_clip)
log_partition = torch.logsumexp(logits_chunk, dim=-1, keepdim=True)
# For some reason, the PyTorch cross entropy is super slow with inputs with large cardinality (e.g. 32000)
# so we reimplement the cross entropy ourselves...
ce_chunks.append(log_partition - logits_chunk.gather(-1, targets_chunk[..., None]))
ce = torch.cat(ce_chunks, dim=0)
ce = ce[..., 0]
ce = torch.where(mask, ce, torch.zeros(1, device=ce.device, dtype=ce.dtype))
return ce.view(output_shape)
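
Note (not part of the diff): a usage sketch for the helper above. Shapes are arbitrary, and since the function is decorated with torch_compile_lazy the first call may trigger compilation.

import torch
from moshi.utils.utils import cross_entropy

B, K, T, card = 2, 3, 5, 32
logits = torch.randn(B, K, T, card)
targets = torch.randint(card, (B, K, T))
mask = torch.ones(B, K, T, dtype=torch.bool)
mask[..., -1] = False                  # e.g. trailing positions invalidated by the delays
ce = cross_entropy(logits, targets, mask, logits_soft_clip=30.0)
print(ce.shape)                        # [B, K, T], zeros at masked positions
print(ce.mean(dim=(0, 2)))             # per-codebook cross entropy, as in the tests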

View file

@ -35,6 +35,7 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
dev = [
"pyright",
"pytest",
"flake8",
"pre-commit",
"gradio-webrtc>=0.0.18"

Binary file not shown.

Binary file not shown.

Binary file not shown.

moshi/tests/test_lm.py (new file, 70 lines)
View file

@ -0,0 +1,70 @@
from pathlib import Path
from safetensors.torch import load_file, load_model
import torch
from moshi.models import lm
from moshi.utils.utils import cross_entropy
def _get_assets() -> Path:
return Path(__file__).parent / 'assets'
def _get_lm(device=None, dtype=torch.float32) -> lm.LMModel:
torch.manual_seed(1234)
model = lm.LMModel(
delays=[0, 1, 2, 4],
n_q=3,
dep_q=3,
card=32,
text_card=48,
dim=16,
num_layers=2,
num_heads=1,
hidden_scale=1,
depformer_dim=16,
depformer_multi_linear=True,
depformer_weights_per_step=True,
depformer_weights_per_step_schedule=[0, 1, 1],
depformer_low_rank_embeddings=8,
depformer_num_heads=1,
depformer_gating='silu',
context=4,
device=device,
dtype=dtype,
)
return model
def test_init():
_get_lm(dtype=torch.float32)
_get_lm(dtype=torch.bfloat16)
_get_lm(dtype=torch.float16)
@torch.no_grad
def test_forward():
model = _get_lm()
load_model(model, _get_assets() / 'test_lm_model.safetensors')
codes = load_file(_get_assets() / 'test_lm_codes.safetensors')['codes']
out = model(codes)
assert out.logits is not None
assert out.text_logits is not None
assert out.mask.shape == codes[:, 1:].shape
assert out.text_mask.shape == codes[:, :1].shape
assert out.logits.shape[:-1] == codes[:, 1:].shape
assert out.logits.shape[-1] == model.card
assert out.text_logits.shape[-1] == model.text_card
ref_out = load_file(_get_assets() / 'test_lm_out.safetensors')
assert (ref_out['mask'] == out.mask).all()
assert (ref_out['text_mask'] == out.text_mask).all()
ce = cross_entropy(out.logits, codes[:, 1:], out.mask)
ce_ref = cross_entropy(ref_out['logits'], codes[:, 1:], out.mask)
delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
assert delta.amax() <= 1e-6, delta.amax()
ce = cross_entropy(out.text_logits, codes[:, :1], out.text_mask)
ce_ref = cross_entropy(ref_out['text_logits'], codes[:, :1], out.text_mask)
delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
assert delta.amax() <= 1e-6, delta.amax()
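
The three .safetensors assets loaded above correspond to the binary files added in this commit. A rough sketch of how such reference files could be regenerated (this workflow is an assumption, not part of the commit; the random codes respect text_card=48 for the text stream and card=32 for the three audio streams):

import torch
from safetensors.torch import save_file, save_model

# Assumes the _get_lm / _get_assets helpers defined above are in scope.
torch.manual_seed(1234)
model = _get_lm()
codes = torch.cat([
    torch.randint(48, (1, 1, 16)),  # text stream, text_card=48
    torch.randint(32, (1, 3, 16)),  # three audio streams, card=32
], dim=1)
with torch.no_grad():
    out = model(codes)
save_model(model, str(_get_assets() / 'test_lm_model.safetensors'))
save_file({'codes': codes}, str(_get_assets() / 'test_lm_codes.safetensors'))
save_file({'logits': out.logits, 'text_logits': out.text_logits,
           'mask': out.mask, 'text_mask': out.text_mask},
          str(_get_assets() / 'test_lm_out.safetensors'))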

View file

@ -13,7 +13,7 @@ from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file
from moshi.models import loaders
from moshi.modules.transformer import quantize_transformer
from moshi.utils.quantize import replace_linear_with_qlinear
def get_api():
@ -35,9 +35,9 @@ def main():
print("Downloading base model.")
info = loaders.CheckpointInfo.from_hf_repo(args.hf_repo)
print("Creating model.")
model = info.get_moshi(device='cuda')
model = info.get_moshi(fuse_lora=True, device='cuda')
print("Quantizing model.")
quantize_transformer(model)
replace_linear_with_qlinear(model)
if args.new_hf_repo is None:
new_repo = repo.rsplit('-', 1)[0] + '-q8'
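
In short, the updated quantization script now fuses the LoRA weights into the base Linear layers before swapping them for quantized ones; a minimal sketch of that flow (the repo name below is hypothetical):

from moshi.models import loaders
from moshi.utils.quantize import replace_linear_with_qlinear

info = loaders.CheckpointInfo.from_hf_repo("kyutai/some-finetuned-repo")  # hypothetical repo
model = info.get_moshi(fuse_lora=True, device='cuda')  # merge LoRA deltas into the base weights
replace_linear_with_qlinear(model)  # then quantize the fused Linear layers in place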

142
scripts/import_mlx_lora.py Normal file
View file

@ -0,0 +1,142 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from moshi.models import loaders
from pathlib import Path
import torch
from safetensors.torch import save_file
def import_model(
tch_model,
out_path: Path,
silent: bool = False,
max_out_n_q: int | None = None,
) -> None:
in_n_q: int | None = None
for idx in range(999):
name = f"emb.{idx}.weight"
if name not in tch_model:
in_n_q = idx
break
out_n_q: int | None = None
for idx in range(999):
name = f"linears.{idx}.weight"
if name not in tch_model:
out_n_q = idx
break
assert in_n_q is not None
assert out_n_q is not None
if not silent:
print(f"in_n_q: {in_n_q}, out_n_q: {out_n_q}")
depformer_layers: int | None = None
for idx in range(999):
if f"depformer.layers.{idx}.self_attn.in_projs.0.weight" not in tch_model:
depformer_layers = idx
break
assert depformer_layers is not None
if not silent:
print(f"depformer layers: {depformer_layers}")
model = {}
for name in ["text_emb.weight", "text_linear.weight"]:
model[name] = tch_model[name]
for name in tch_model.keys():
if name.startswith("condition_provider.conditioners"):
model[name] = tch_model[name]
model["out_norm.weight"] = tch_model["out_norm.alpha"][0, 0]
for idx in range(in_n_q):
src_name = f"emb.{idx}.weight"
dst_name = f"audio_embs.{idx}.weight"
model[dst_name] = tch_model[src_name]
for k, v in sorted(tch_model.items()):
        if not silent:
            print(k, v.shape, v.dtype)
if k.startswith("transformer"):
if k.endswith(".alpha"):
v = v[0, 0]
k = k.replace(".alpha", ".weight")
k = k.replace(".in_projs.0.weight", ".in_proj.weight")
k = k.replace(".out_projs.0.weight", ".out_proj.weight")
model[k] = v
    # Only export the first depformer slices (the main ones) when limited by max_out_n_q.
if max_out_n_q is not None:
exported_out_n_q = min(max_out_n_q, out_n_q)
print(f"only exporting the first {exported_out_n_q} depformer layers")
else:
exported_out_n_q = out_n_q
for idx in range(exported_out_n_q):
tch_idx = idx
base = f"depformer.slices.{idx}."
model[base + "linear_in.weight"] = tch_model[f"depformer_in.{tch_idx}.weight"].clone()
model[base + "linear_out.weight"] = tch_model[f"linears.{idx}.weight"]
if idx == 0:
model[base + "emb.weight"] = tch_model["depformer_text_emb.weight"]
if "depformer_text_emb.low_rank.weight" in tch_model:
model[base + "emb.low_rank.weight"] = tch_model["depformer_text_emb.low_rank.weight"].clone()
else:
model[base + "emb.weight"] = tch_model[f"depformer_emb.{idx-1}.weight"].clone()
if f"depformer_emb.{tch_idx-1}.low_rank.weight" in tch_model:
model[base + "emb.low_rank.weight"] = tch_model[f"depformer_emb.{idx-1}.low_rank.weight"].clone()
for layer_idx in range(depformer_layers):
layer = base + f"transformer.layers.{layer_idx}."
model[layer + "self_attn.in_proj.weight"] = (
tch_model[f"depformer.layers.{layer_idx}.self_attn.in_projs.{tch_idx}.weight"]
)
model[layer + "self_attn.out_proj.weight"] = (
tch_model[f"depformer.layers.{layer_idx}.self_attn.out_projs.{tch_idx}.weight"]
)
model[layer + "norm1.weight"] = tch_model[
f"depformer.layers.{layer_idx}.norm1.alpha"
][0, 0].clone()
model[layer + "norm2.weight"] = tch_model[
f"depformer.layers.{layer_idx}.norm2.alpha"
][0, 0].clone()
model[layer + "gating.linear_in.weight"] = tch_model[
f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_in.weight"
].clone()
model[layer + "gating.linear_out.weight"] = tch_model[
f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_out.weight"
].clone()
save_file(model, out_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
help="HF repo to look into, defaults Moshiko. "
"Use this to select a different pre-trained model.")
parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
parser.add_argument(
"-s", "--silent", action="store_true", help="only prints the checkpoint name"
)
parser.add_argument(
"--max-out-n-q",
type=int,
help="limit the number of depformer layers that are exported",
)
parser.add_argument("out", type=str, help="the mlx safetensors file to generate")
args = parser.parse_args()
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
lora_weights=args.lora_weight, config_path=args.config_path)
lm = checkpoint_info.get_moshi(device="cpu", dtype=torch.bfloat16, fuse_lora=True)
for key, value in lm.state_dict().items():
print(key, value.shape)
out_path = Path(args.out)
import_model(lm.state_dict(), out_path, silent=args.silent, max_out_n_q=args.max_out_n_q)
if __name__ == "__main__":
main()
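
A quick sanity check one might run on the exported file (the output path is hypothetical; the expected key names follow the renames performed by import_model above):

from safetensors.torch import load_file

weights = load_file("moshi_mlx.safetensors")  # hypothetical output path
# A few keys the MLX importer above is expected to produce:
assert "out_norm.weight" in weights                      # from out_norm.alpha[0, 0]
assert "audio_embs.0.weight" in weights                  # from emb.0.weight
assert "depformer.slices.0.linear_in.weight" in weights  # from depformer_in.0.weight
print(f"{len(weights)} tensors exported")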

139
scripts/import_rust_lora.py Normal file
View file

@ -0,0 +1,139 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from moshi.models import loaders
from pathlib import Path
import torch
from safetensors.torch import save_file
def import_model(
tch_model,
max_out_n_q: int | None,
out_path: Path,
) -> None:
for k, v in sorted(tch_model.items()):
print(k, v.shape, v.dtype)
in_n_q: int | None = None
for idx in range(999):
name = f"emb.{idx}.weight"
if name not in tch_model:
in_n_q = idx
break
out_n_q: int | None = None
for idx in range(999):
name = f"linears.{idx}.weight"
if name not in tch_model:
out_n_q = idx
break
assert in_n_q is not None
assert out_n_q is not None
depformer_layers: int | None = None
for idx in range(999):
if f"depformer.layers.{idx}.self_attn.in_projs.0.weight" not in tch_model:
depformer_layers = idx
break
assert depformer_layers is not None
print(f"depformer layers: {depformer_layers}")
model = {}
for name in ["text_emb.weight", "text_linear.weight", "out_norm.alpha"]:
model[name] = tch_model[name]
for name in tch_model.keys():
if name.startswith("condition_provider.conditioners"):
model[name] = tch_model[name]
for idx in range(in_n_q):
name = f"emb.{idx}.weight"
model[name] = tch_model[name]
for k, v in sorted(tch_model.items()):
if k.startswith("transformer"):
k = k.replace(".in_projs.0.weight", ".in_proj.weight")
k = k.replace(".out_projs.0.weight", ".out_proj.weight")
model[k] = v
    # Only export the first depformer slices (the main ones) when limited by max_out_n_q.
if max_out_n_q is not None:
exported_out_n_q = min(max_out_n_q, out_n_q)
print(f"only exporting the first {exported_out_n_q} depformer layers")
else:
print(f"exporting all {out_n_q} depformer layers")
exported_out_n_q = out_n_q
max_df_steps = out_n_q
for idx in range(exported_out_n_q):
tch_idx = idx
base = f"depformer.{idx}."
model[base + "linear_in.weight"] = tch_model[f"depformer_in.{tch_idx}.weight"].clone()
model[base + "linear_out.weight"] = tch_model[f"linears.{idx}.weight"]
if idx == 0:
model[base + "emb.weight"] = tch_model["depformer_text_emb.weight"]
if "depformer_text_emb.low_rank.weight" in tch_model:
model[base + "emb.low_rank.weight"] = tch_model["depformer_text_emb.low_rank.weight"].clone()
else:
model[base + "emb.weight"] = tch_model[f"depformer_emb.{tch_idx-1}.weight"].clone()
if f"depformer_emb.{tch_idx-1}.low_rank.weight" in tch_model:
model[base + "emb.low_rank.weight"] = tch_model[f"depformer_emb.{tch_idx-1}.low_rank.weight"].clone()
for layer_idx in range(depformer_layers):
layer = base + f"transformer.layers.{layer_idx}."
# WARNING: note that this uses in_proj_weight vs out_proj.weight
model[layer + "self_attn.in_proj_weight"] = (
tch_model[f"depformer.layers.{layer_idx}.self_attn.in_projs.{tch_idx}.weight"]
)
model[layer + "self_attn.out_proj.weight"] = (
tch_model[f"depformer.layers.{layer_idx}.self_attn.out_projs.{tch_idx}.weight"]
)
model[layer + "norm1.alpha"] = tch_model[
f"depformer.layers.{layer_idx}.norm1.alpha"
].clone()
model[layer + "norm2.alpha"] = tch_model[
f"depformer.layers.{layer_idx}.norm2.alpha"
].clone()
model[layer + "gating.linear_in.weight"] = tch_model[
f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_in.weight"
].clone()
model[layer + "gating.linear_out.weight"] = tch_model[
f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_out.weight"
].clone()
save_file(model, out_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
help="HF repo to look into, defaults Moshiko. "
"Use this to select a different pre-trained model.")
parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
parser.add_argument(
"-s", "--silent", action="store_true", help="only prints the checkpoint name"
)
parser.add_argument(
"--max-out-n-q",
type=int,
help="limit the number of depformer layers that are exported",
)
parser.add_argument("out", type=str, help="the rust safetensors file to generate")
args = parser.parse_args()
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
lora_weights=args.lora_weight, config_path=args.config_path)
lm = checkpoint_info.get_moshi(device="cpu", dtype=torch.bfloat16, fuse_lora=True)
for key, value in lm.state_dict().items():
print(key, value.shape)
out_path = Path(args.out)
import_model(lm.state_dict(), out_path=out_path, max_out_n_q=args.max_out_n_q)
if __name__ == "__main__":
main()
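
And similarly for the Rust export; unlike the MLX importer, this one keeps the .alpha norm parameters as-is and uses in_proj_weight (no dot) for the depformer attention, per the warning comment above (the output path is hypothetical):

from safetensors.torch import load_file

weights = load_file("moshi_rust.safetensors")  # hypothetical output path
assert "out_norm.alpha" in weights  # kept as-is, unlike the MLX export
assert "depformer.0.transformer.layers.0.self_attn.in_proj_weight" in weights
assert "depformer.0.transformer.layers.0.self_attn.out_proj.weight" in weights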