Adding LoRA support (#249)
* adding training forward
* tests
* tests
* fix
* plop
* plop
* plop
* lora all but not multilinear
* all lora works+lazy comp
* minor
* minor fix
* lora ckpt loaders
* checkpointing
* checkpointing + new lora loading
* tokenizer path
* quantization back (to test)
* small fixes
* fuse_lora arg
* remove compile
* edits from alex
* edits from alex
* trying dynamic no compile
* reformat configs + loading
* wip
* wip
* fix typing
* fixing loading of old models
* some improvements
* remove param
* plop
* Lora and quant (#248)
* wip
* wip
* fix typing
* fixing loading of old models
* some improvements
* remove param
* plop
* fix
* fix
* fine-tune loading
* Import lora models for mlx.
* meta init
* exchange order
* Rust lora importer.
* update load weight
* fixing meta init
* load weight in moshi
* no post hook
* fix
* dbg

---------

Co-authored-by: hippolytepilchen <hippolyte.pilchen@gmail.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
Parent: 5707114ca4 · Commit: ea5401cc3a
22 changed files with 1170 additions and 284 deletions
.github/actions/rust_build/action.yml (vendored): 3 changes
|
@@ -31,3 +31,6 @@ runs:
run: |
  sudo apt-get update
  sudo apt-get install libasound2-dev
  echo "test"
  cmake --version
  apt-cache show cmake
|
|
|
@@ -13,6 +13,12 @@ repos:
    entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pyright'
    pass_filenames: false
    always_run: true
  - id: tests-moshi
    name: pytests on moshi package
    language: system
    entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pytest tests'
    pass_filenames: false
    always_run: true
  - id: flake8-moshi_mlx
    name: flake8 on moshi_mlx package
    language: system
|
|
|
@@ -3,3 +3,4 @@ include *.md
include *.cfg
include requirements.txt
include moshi/py.typed
include tests/assets/*.safetensors
|
|
|
@@ -16,4 +16,4 @@ from . import modules
from . import quantization
from . import utils

__version__ = "0.2.3"
__version__ = "0.2.4a1"
|
|
|
@ -16,7 +16,6 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from ..modules.transformer import create_sin_embedding
|
||||
from ..utils import quantize
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -212,7 +211,7 @@ class BaseConditioner(nn.Module, tp.Generic[Prepared]):
|
|||
cond = torch.zeros(B, T, C, device=cond.device, dtype=cond.dtype)
|
||||
mask = torch.zeros_like(cond[..., 0], dtype=torch.bool)
|
||||
|
||||
cond = quantize.linear(self.output_proj, cond)
|
||||
cond = self.output_proj(cond)
|
||||
|
||||
maskf = mask.float()[..., None]
|
||||
if self.learnt_padding is not None:
|
||||
|
|
|
@ -13,59 +13,31 @@ from dataclasses import dataclass, field
|
|||
from functools import partial
|
||||
import logging
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..conditioners import ConditionProvider, ConditionFuser, ConditionTensors
|
||||
from ..utils.sampling import sample_token
|
||||
from ..utils.compile import CUDAGraphed
|
||||
from ..utils import quantize
|
||||
from ..utils.quantize import replace_linear_with_qlinear
|
||||
from ..modules.streaming import StreamingContainer, StreamingModule, State
|
||||
from ..modules.transformer import (
|
||||
StreamingTransformer,
|
||||
quantize_transformer,
|
||||
create_norm_fn,
|
||||
)
|
||||
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
||||
from .lm_utils import (_delay_sequence,
|
||||
_undelay_sequence,
|
||||
_init_layer,
|
||||
ScaledEmbedding)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Embedding):
|
||||
"""Boost learning rate for embeddings (with `scale`).
|
||||
|
||||
Args:
|
||||
norm (bool): if True, uses a layer norm after the embedding.
|
||||
zero_idx (int): special value indicating that the output should be exactly 0.
|
||||
low_rank (int | None): if provided, uses low rank embedding with a linear layer to reach
|
||||
the desired dimension. Quite efficient for reducing the number of weights for very large vocabs.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int,
|
||||
*args, norm: bool = False, zero_idx: int = -1,
|
||||
low_rank: int | None = None, **kwargs):
|
||||
super().__init__(num_embeddings, low_rank or embedding_dim, *args, **kwargs)
|
||||
self.norm = None
|
||||
if norm:
|
||||
self.norm = create_norm_fn("layer_norm", self.embedding_dim)
|
||||
assert zero_idx < 0, "Please use negative values for the zero_idx."
|
||||
self.zero_idx = zero_idx
|
||||
self.low_rank = None
|
||||
if low_rank is not None:
|
||||
self.low_rank = nn.Linear(low_rank, embedding_dim, bias=False)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
is_zero = input == self.zero_idx
|
||||
zero = torch.zeros(1, dtype=input.dtype, device=input.device)
|
||||
input = input.clamp(min=0)
|
||||
y = super().forward(input, *args, **kwargs)
|
||||
if self.norm is not None:
|
||||
y = self.norm(y)
|
||||
y = torch.where(is_zero[..., None], zero, y)
|
||||
if self.low_rank is not None:
|
||||
y = quantize.linear(self.low_rank, y)
|
||||
return y
|
||||
@dataclass
|
||||
class LMOutput:
|
||||
# The logits are already re-aligned with the input codes
|
||||
# hence no extra shift is required, e.g. when computing CE
|
||||
logits: torch.Tensor # [B, K, T, card]
|
||||
mask: torch.Tensor # [B, K, T]
|
||||
text_logits: torch.Tensor # [B, 1, T, text_card]
|
||||
text_mask: torch.Tensor # [B, 1, T]
|
||||
|
||||
|
||||
class LMModel(StreamingContainer):
|
||||
|
@ -88,8 +60,7 @@ class LMModel(StreamingContainer):
|
|||
depformer_dim_feedforward (int| list[int]| None): If None, defaults to hidden_scale * depformer_dim.
|
||||
depformer_weights_per_step_schedule (list[int] | None): mapping `CODEBOOK_INDEX -> WEIGHT_INDEX`, allowing
|
||||
depformer_low_rank_embeddings (int | None): if provided, uses low rank embeddings, with a linear
|
||||
existing_text_padding_id (bool): if True, will use a different token for the initial text token, and
|
||||
the text padding token.
|
||||
existing_text_padding_id (int): token to use for the padding.
|
||||
same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
|
||||
**kwargs: Additional parameters for the transformer encoder.
|
||||
"""
|
||||
|
@ -114,13 +85,16 @@ class LMModel(StreamingContainer):
|
|||
depformer_weights_per_step_schedule: list[int] | None = None,
|
||||
depformer_low_rank_embeddings: int | None = None,
|
||||
depformer_pos_emb: str = "sin",
|
||||
existing_text_padding_id: tp.Optional[int] = None,
|
||||
existing_text_padding_id: int = 3,
|
||||
existing_text_end_padding_id: int = 0,
|
||||
context: tp.Optional[int] = None,
|
||||
causal: bool = True,
|
||||
condition_provider: tp.Optional[ConditionProvider] = None,
|
||||
fuser: tp.Optional[ConditionFuser] = None,
|
||||
quantize: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -132,11 +106,11 @@ class LMModel(StreamingContainer):
|
|||
self.delays = delays
|
||||
self.dim = dim
|
||||
self.existing_text_padding_id = existing_text_padding_id
|
||||
self.existing_text_end_padding_id = existing_text_end_padding_id
|
||||
self.context = context
|
||||
self.depformer_weights_per_step_schedule = depformer_weights_per_step_schedule
|
||||
if depformer_weights_per_step_schedule is not None:
|
||||
assert len(depformer_weights_per_step_schedule) == dep_q
|
||||
kwargs["context"] = context
|
||||
EmbeddingFactory = partial(
|
||||
ScaledEmbedding,
|
||||
norm=norm_emb,
|
||||
|
@ -147,11 +121,10 @@ class LMModel(StreamingContainer):
|
|||
self.emb = nn.ModuleList(
|
||||
[EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
|
||||
)
|
||||
# Text card + padding token (if not in the original tokenizer)
|
||||
extra_text = self.existing_text_padding_id is None
|
||||
# Unlike for audio, here we authorize the model to output the special token.
|
||||
self.text_emb = EmbeddingFactory(text_card + 1, dim)
|
||||
self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
|
||||
|
||||
self.text_linear = nn.Linear(dim, text_card, bias=bias_proj)
|
||||
depformer_prefix = "depformer_"
|
||||
main_kwargs = {
|
||||
k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
|
||||
|
@ -164,6 +137,9 @@ class LMModel(StreamingContainer):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
quantize=quantize,
|
||||
context=context,
|
||||
causal=causal,
|
||||
checkpointing=gradient_checkpointing,
|
||||
**main_kwargs,
|
||||
)
|
||||
self.out_norm = create_norm_fn(norm, dim)
|
||||
|
@ -205,7 +181,9 @@ class LMModel(StreamingContainer):
|
|||
dim_feedforward=depformer_dim_feedforward,
|
||||
norm=norm,
|
||||
weights_per_step_schedule=depformer_weights_per_step_schedule,
|
||||
causal=causal,
|
||||
quantize=quantize,
|
||||
checkpointing=gradient_checkpointing,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kwargs_dep,
|
||||
|
@ -221,8 +199,9 @@ class LMModel(StreamingContainer):
|
|||
self.condition_provider = condition_provider
|
||||
self.fuser = fuser
|
||||
self.to(device=device, dtype=dtype)
|
||||
self._init_weights()
|
||||
if quantize:
|
||||
quantize_transformer(self)
|
||||
replace_linear_with_qlinear(self)
|
||||
|
||||
@property
|
||||
def initial_token_id(self) -> int:
|
||||
|
@ -237,15 +216,12 @@ class LMModel(StreamingContainer):
|
|||
@property
|
||||
def text_padding_token_id(self) -> int:
|
||||
"""Token id for text padding."""
|
||||
if self.existing_text_padding_id is None:
|
||||
return self.text_card
|
||||
else:
|
||||
return self.existing_text_padding_id
|
||||
return self.existing_text_padding_id
|
||||
|
||||
@property
|
||||
def end_of_text_padding_id(self) -> int:
|
||||
"""Token id for optionally marking the last padding step for a word."""
|
||||
return 0
|
||||
return self.existing_text_end_padding_id
|
||||
|
||||
@property
|
||||
def zero_token_id(self) -> int:
|
||||
|
@ -294,6 +270,61 @@ class LMModel(StreamingContainer):
|
|||
token = torch.cat([text_token, audio_token], dim=1)
|
||||
return token
|
||||
|
||||
def forward(
|
||||
self, codes: torch.Tensor,
|
||||
condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
|
||||
"""Given an input tensor of codes [B, K, T] and list of conditions, returns the logits
|
||||
along with masks indicating the valid positions at which to compute the loss.
|
||||
The logits time steps are aligned with those in the input `code`.
|
||||
Should only be used for training, not inference (use `LMGen` for that).
|
||||
|
||||
Args:
|
||||
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
||||
K the number of codebooks and T the number of timesteps. When text is supported,
|
||||
the first 'codebook' corresponds to the text, and the remaining codebooks are for the audio.
|
||||
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning tensors.
|
||||
Returns:
|
||||
LMOutput: Language model outputs, containing either text or audio logits, or both.
|
||||
logits (torch.Tensor, or None) of shape [B, K, T, card] corresponding to the provided codes,
|
||||
i.e. the first item corresponds to logits to predict the first code, meaning that
|
||||
no additional shifting of codes and logits is required.
|
||||
mask (torch.Tensor, or None) of shape [B, K, T], mask over valid and invalid positions.
|
||||
Given the specified interleaving strategies, parts of the logits and codes should
|
||||
not be considered as valid predictions because of invalid context.
|
||||
text_logits (torch.Tensor, or None) of shape [B, 1, T, text_card].
|
||||
text_mask (torch.Tensor, or None) of shape [B, 1, T], mask over the valid positions for the text.
|
||||
"""
|
||||
B, K, T = codes.shape
|
||||
assert K == self.num_codebooks, (K, self.num_codebooks)
|
||||
# Delaying codes and removing the last time step that will never be an input.
|
||||
initial = self._get_initial_token().expand(B, -1, -1)
|
||||
delayed_codes = _delay_sequence(self.delays, codes, initial)
|
||||
# Inserting the empty tokens for the first time step.
|
||||
delayed_codes = torch.cat([initial, delayed_codes], dim=2)
|
||||
|
||||
sum_condition: torch.Tensor | None = None
|
||||
if condition_tensors is None:
|
||||
assert self.fuser is None
|
||||
else:
|
||||
assert self.fuser is not None
|
||||
sum_condition = self.fuser.get_sum(condition_tensors)
|
||||
|
||||
transformer_out, text_logits = self.forward_text(delayed_codes[:, :, :-1], sum_condition)
|
||||
assert transformer_out.shape[0] == delayed_codes.shape[0]
|
||||
assert transformer_out.shape[1] == delayed_codes.shape[2] - 1
|
||||
logits = self.forward_depformer_training(delayed_codes[:, :, 1:], transformer_out)
|
||||
|
||||
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
||||
# and provide the corresponding mask over invalid positions of tokens. We will fill invalid positions with NaN values
|
||||
# to ensure they are properly handled.
|
||||
logits, logits_mask = _undelay_sequence(
|
||||
self.delays[self.audio_offset:self.audio_offset + self.dep_q],
|
||||
logits, fill_value=float('NaN'))
|
||||
logits_mask &= (codes[:, self.audio_offset: self.audio_offset + self.dep_q] != self.zero_token_id)
|
||||
text_logits, text_logits_mask = _undelay_sequence(self.delays[:1], text_logits, fill_value=float('NaN'))
|
||||
text_logits_mask &= (codes[:, :1] != self.zero_token_id)
|
||||
return LMOutput(logits, logits_mask, text_logits, text_logits_mask)
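For reference, a minimal sketch (not part of the diff) of how this training-only forward is meant to be called, assuming an already constructed `LMModel` instance `lm` without a condition fuser (so `condition_tensors` can stay `None`), with the dummy codes and the model on the same device; the batch size and sequence length are arbitrary:

codes = torch.zeros(1, lm.num_codebooks, 25, dtype=torch.long)   # dummy [B, K, T] codes, text codebook first
out = lm.forward(codes)        # LMOutput(logits, mask, text_logits, text_mask)
# out.mask / out.text_mask flag the positions that are valid targets for the CE loss.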
|
||||
|
||||
def forward_text(
|
||||
self,
|
||||
sequence: torch.Tensor, sum_condition: torch.Tensor | None = None
|
||||
|
@ -310,18 +341,54 @@ class LMModel(StreamingContainer):
|
|||
)
|
||||
input_ = audio_emb if input_ is None else input_ + audio_emb
|
||||
text_emb = self.text_emb(input_sequence[:, 0])
|
||||
|
||||
input_ = text_emb if input_ is None else input_ + text_emb
|
||||
if sum_condition is not None:
|
||||
input_ = input_ + sum_condition.to(input_)
|
||||
transformer_out = self.transformer(input_)
|
||||
|
||||
if self.out_norm:
|
||||
transformer_out = self.out_norm(transformer_out)
|
||||
assert isinstance(transformer_out, torch.Tensor)
|
||||
text_logits = quantize.linear(self.text_linear, transformer_out)
|
||||
text_logits = self.text_linear(transformer_out)
|
||||
text_logits = text_logits[:, None]
|
||||
return transformer_out, text_logits
|
||||
|
||||
def forward_depformer_training(
|
||||
self,
|
||||
sequence: torch.Tensor,
|
||||
transformer_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
B, K, T = sequence.shape
|
||||
Ka = self.dep_q
|
||||
assert (
|
||||
K == self.num_codebooks
|
||||
), f"Codebooks for Depformer training should be passed all at once, got {K,}."
|
||||
depformer_inputs = []
|
||||
for cb_index in range(Ka):
|
||||
if self.depformer_multi_linear:
|
||||
linear_index = cb_index
|
||||
if self.depformer_weights_per_step_schedule is not None:
|
||||
linear_index = self.depformer_weights_per_step_schedule[cb_index]
|
||||
transformer_in = self.depformer_in[linear_index](transformer_out)
|
||||
else:
|
||||
transformer_in = self.depformer_in[0](transformer_out)
|
||||
if cb_index == 0:
|
||||
token_in = self.depformer_text_emb(sequence[:, 0])
|
||||
else:
|
||||
token_in = self.depformer_emb[cb_index - 1](sequence[:, cb_index + self.audio_offset - 1])
|
||||
depformer_inputs.append(token_in + transformer_in)
|
||||
depformer_input = torch.stack(depformer_inputs, 2)
|
||||
# depformer_input is [B, T, K, depformer_dim], reshaping to [B * T, K, D]
|
||||
depformer_input = depformer_input.view(B * T, Ka, -1)
|
||||
depformer_output = self.depformer(depformer_input)
|
||||
all_logits = []
|
||||
for cb_index in range(Ka):
|
||||
logits = self.linears[cb_index](depformer_output[:, cb_index])
|
||||
all_logits.append(logits.view(B, T, -1))
|
||||
logits = torch.stack(all_logits, 1)
|
||||
assert logits.dim() == 4, logits.shape # [B, Ka, T, card]
|
||||
return logits
|
||||
|
||||
def forward_depformer(
|
||||
self,
|
||||
depformer_cb_index: int,
|
||||
|
@ -344,9 +411,9 @@ class LMModel(StreamingContainer):
|
|||
in_index = depformer_cb_index
|
||||
if self.depformer_weights_per_step_schedule is not None:
|
||||
in_index = self.depformer_weights_per_step_schedule[in_index]
|
||||
depformer_input = quantize.linear(self.depformer_in[in_index], depformer_input)
|
||||
depformer_input = self.depformer_in[in_index](depformer_input)
|
||||
else:
|
||||
depformer_input = quantize.linear(self.depformer_in[0], depformer_input)
|
||||
depformer_input = self.depformer_in[0](depformer_input)
|
||||
if depformer_cb_index == 0:
|
||||
last_token_input = self.depformer_text_emb(sequence[:, 0])
|
||||
else:
|
||||
|
@ -359,11 +426,35 @@ class LMModel(StreamingContainer):
|
|||
# depformer_input is [B, 1, depformer_dim].
|
||||
# The streaming state of the depformer ensures that the proper layer is run.
|
||||
dep_output = self.depformer(depformer_input)
|
||||
logits = quantize.linear(self.linears[depformer_cb_index], dep_output)
|
||||
logits = self.linears[depformer_cb_index](dep_output)
|
||||
logits = logits[:, None]
|
||||
assert logits.dim() == 4, logits.shape # [B, Ka, S, card]
|
||||
return logits
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialization of the transformer module weights.
|
||||
Mostly truncated gaussian, with `std = 1 / sqrt(dim_in)`.
|
||||
Embeddings are also initialized with `1 / sqrt(dim)` rather than `1`.
|
||||
Some layers are not going to be properly initialized:
|
||||
- in_proj in MHA.
|
||||
- depth transformer layers.
|
||||
This is to match how our models were trained so far.
|
||||
"""
|
||||
|
||||
for emb_layer in self.emb:
|
||||
_init_layer(emb_layer)
|
||||
for emb_layer in self.depformer_emb:
|
||||
_init_layer(emb_layer)
|
||||
_init_layer(self.text_emb)
|
||||
_init_layer(self.depformer_text_emb)
|
||||
_init_layer(self.text_linear)
|
||||
|
||||
for tr_layer in self.transformer.layers:
|
||||
tr_layer.apply(_init_layer)
|
||||
|
||||
for linear in self.linears:
|
||||
_init_layer(linear)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LMGenState(State):
|
||||
|
@ -482,7 +573,7 @@ class LMGen(StreamingModule[_LMGenState]):
|
|||
k = lm_model.dep_q + 1 + q_other
|
||||
delay = lm_model.delays[k]
|
||||
write_position = (state.offset + delay) % CT
|
||||
state.cache[:, k, write_position : write_position + 1] = input_tokens[
|
||||
state.cache[:, k, write_position: write_position + 1] = input_tokens[
|
||||
:, q_other
|
||||
]
|
||||
|
||||
|
@ -492,7 +583,7 @@ class LMGen(StreamingModule[_LMGenState]):
|
|||
# token that are delayed, and thus have no good value to take.
|
||||
if state.offset <= delay:
|
||||
state.cache[:, k, position] = state.initial[:, k, 0]
|
||||
input_ = state.cache[:, :, position : position + 1]
|
||||
input_ = state.cache[:, :, position: position + 1]
|
||||
|
||||
if self.check:
|
||||
# Check that we are not feeding in any value that is not generated yet.
|
||||
|
@ -526,7 +617,7 @@ class LMGen(StreamingModule[_LMGenState]):
|
|||
state.offset += 1
|
||||
position = state.offset % CT
|
||||
state.cache[:, 0, position] = text_token
|
||||
state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
|
||||
state.cache[:, 1:lm_model.dep_q + 1, position] = audio_tokens
|
||||
|
||||
if state.offset <= self.max_delay:
|
||||
return None
|
||||
|
|
moshi/moshi/models/lm_utils.py (new file, 107 lines)
|
@ -0,0 +1,107 @@
|
|||
import math
|
||||
import typing as tp
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..modules.transformer import create_norm_fn
|
||||
|
||||
|
||||
def _delay_sequence(delays: tp.List[int], tensor: torch.Tensor, padding: torch.Tensor) -> torch.Tensor:
|
||||
B, K, T = tensor.shape
|
||||
assert len(delays) == K, (len(delays), K)
|
||||
outs = []
|
||||
|
||||
for k, delay in enumerate(delays):
|
||||
assert delay >= 0
|
||||
line = tensor[:, k].roll(delay, dims=1)
|
||||
if delay > 0:
|
||||
line[:, :delay] = padding[:, k]
|
||||
outs.append(line)
|
||||
return torch.stack(outs, dim=1)
|
||||
|
||||
|
||||
def _undelay_sequence(delays: tp.List[int], tensor: torch.Tensor,
|
||||
fill_value: tp.Union[int, float] = float('NaN')) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||
B, K, T, *_ = tensor.shape
|
||||
assert len(delays) == K
|
||||
mask = torch.ones(B, K, T, dtype=torch.bool, device=tensor.device)
|
||||
outs = []
|
||||
if all([delay == 0 for delay in delays]):
|
||||
return tensor, mask
|
||||
for k, delay in enumerate(delays):
|
||||
assert delay >= 0
|
||||
line = tensor[:, k].roll(-delay, dims=1)
|
||||
if delay > 0:
|
||||
line[:, -delay:] = fill_value
|
||||
mask[:, k, -delay:] = 0
|
||||
outs.append(line)
|
||||
return torch.stack(outs, dim=1), mask
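A small illustration (not part of the diff) of how the delay/undelay pair behaves, using float values so the NaN fill is representable; the delays and shapes are made up:

codes = torch.arange(6.).view(1, 2, 3)           # [B=1, K=2, T=3]
pad = torch.zeros(1, 2, 1)                       # per-codebook padding values
delayed = _delay_sequence([0, 1], codes, pad)    # codebook 1 is shifted right by one step, first step padded
undelayed, mask = _undelay_sequence([0, 1], delayed)
# mask[:, 1, -1] is False: that position was lost to the delay and is filled with NaN.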
|
||||
|
||||
|
||||
def _get_init_fn(input_dim: int) -> tp.Callable[[torch.Tensor], None]:
|
||||
def _init(x: torch.Tensor) -> None:
|
||||
std = 1 / math.sqrt(input_dim)
|
||||
x_orig = x
|
||||
if x.device.type == 'cpu' and x.dtype in [torch.float16, torch.bfloat16]:
|
||||
x = x.float()
|
||||
|
||||
torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
||||
if x_orig is not x:
|
||||
x_orig.data[:] = x.to(x_orig)
|
||||
return _init
|
||||
|
||||
|
||||
def _init_layer(m: nn.Module,
|
||||
zero_bias_init: bool = True):
|
||||
if isinstance(m, nn.Linear):
|
||||
init_fn = _get_init_fn(m.in_features)
|
||||
init_fn(m.weight)
|
||||
if zero_bias_init and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Embedding):
|
||||
init_fn = _get_init_fn(m.embedding_dim)
|
||||
init_fn(m.weight)
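As used by `LMModel._init_weights` in this diff, `_init_layer` is applied module by module via `Module.apply`; a short sketch (not part of the diff), with an arbitrary layer size:

layer = nn.Linear(256, 1024)
layer.apply(_init_layer)   # truncated normal with std = 1 / sqrt(in_features); bias zeroed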
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Embedding):
|
||||
"""Boost learning rate for embeddings (with `scale`).
|
||||
|
||||
Args:
|
||||
norm (bool): if True, uses a layer norm after the embedding.
|
||||
zero_idx (int): special value indicating that the output should be exactly 0.
|
||||
low_rank (int | None): if provided, uses low rank embedding with a linear layer to reach
|
||||
the desired dimension. Quite efficient for reducing the number of weights for very large vocabs.
|
||||
lr (float or None): learning rate to use, only valid if the `make_optim_group()` method is used.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int,
|
||||
*args, norm: bool = False, zero_idx: int = -1,
|
||||
low_rank: int | None = None, lr: float | None = None, **kwargs):
|
||||
super().__init__(num_embeddings, low_rank or embedding_dim, *args, **kwargs)
|
||||
self.norm = None
|
||||
if norm:
|
||||
self.norm = create_norm_fn("layer_norm", self.embedding_dim)
|
||||
assert zero_idx < 0, "Please use negative values for the zero_idx."
|
||||
self.zero_idx = zero_idx
|
||||
self.lr = lr
|
||||
self.low_rank = None
|
||||
if low_rank is not None:
|
||||
self.low_rank = nn.Linear(low_rank, embedding_dim, bias=False)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
is_zero = input == self.zero_idx
|
||||
zero = torch.zeros(1, dtype=input.dtype, device=input.device)
|
||||
input = input.clamp(min=0)
|
||||
y = super().forward(input, *args, **kwargs)
|
||||
if self.norm is not None:
|
||||
y = self.norm(y)
|
||||
y = torch.where(is_zero[..., None], zero, y)
|
||||
if self.low_rank is not None:
|
||||
y = self.low_rank(y)
|
||||
return y
|
||||
|
||||
def make_optim_group(self) -> dict:
|
||||
group: dict[str, tp.Any] = {"params": list(self.parameters())}
|
||||
if self.lr is not None:
|
||||
group["lr"] = self.lr
|
||||
return group
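The `lr` stored on the embedding only takes effect through `make_optim_group()`; a hedged usage sketch (not part of the diff, hyperparameters arbitrary):

emb = ScaledEmbedding(32000, 1024, low_rank=128, lr=1e-3)
opt = torch.optim.AdamW([emb.make_optim_group()], lr=1e-4)   # the embedding group keeps lr=1e-3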
|
|
@ -2,36 +2,37 @@
|
|||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Retrieves the pretrained models for Moshi and Mimi."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
try:
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
except ImportError:
|
||||
from huggingface_hub.utils import EntryNotFoundError # pyright: ignore
|
||||
from safetensors.torch import load_model
|
||||
from safetensors.torch import load_model, load_file
|
||||
import sentencepiece
|
||||
import torch
|
||||
import typing as tp
|
||||
|
||||
from .compression import MimiModel
|
||||
from ..conditioners import BaseConditioner, ConditionProvider, ConditionFuser
|
||||
from .lm import LMModel
|
||||
from ..modules import SEANetEncoder, SEANetDecoder, transformer
|
||||
from ..quantization import SplitResidualVectorQuantizer
|
||||
from ..modules.lora import replace_all_linear_with_lora, replace_lora_with_linear
|
||||
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
FRAME_RATE = 12.5
|
||||
|
||||
TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
|
||||
MOSHI_NAME = 'model.safetensors'
|
||||
MOSHI_Q8_NAME = 'model.q8.safetensors'
|
||||
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
|
||||
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
|
||||
TEXT_TOKENIZER_NAME = "tokenizer_spm_32k_3.model"
|
||||
MOSHI_NAME = "model.safetensors"
|
||||
MOSHI_Q8_NAME = "model.q8.safetensors"
|
||||
MIMI_NAME = "tokenizer-e351c8d8-checkpoint125.safetensors"
|
||||
DEFAULT_REPO = "kyutai/moshiko-pytorch-bf16"
|
||||
|
||||
|
||||
_seanet_kwargs = {
|
||||
|
@ -99,7 +100,6 @@ _lm_kwargs = {
|
|||
"depformer_dim_feedforward": int(4.125 * 1024),
|
||||
"depformer_num_heads": 16,
|
||||
"depformer_num_layers": 6,
|
||||
"depformer_causal": True,
|
||||
"depformer_layer_scale": None,
|
||||
"depformer_multi_linear": True,
|
||||
"depformer_context": 8,
|
||||
|
@ -139,20 +139,25 @@ class CheckpointInfo:
|
|||
raw_config: raw config, including original keys not intended for the LM.
|
||||
model_type: indicate the intended use, should be `moshi` or `hibiki`.
|
||||
"""
|
||||
|
||||
moshi_weights: Path
|
||||
mimi_weights: Path
|
||||
tokenizer: Path
|
||||
lm_config: dict | None = None
|
||||
raw_config: dict | None = None
|
||||
model_type: str = 'moshi'
|
||||
model_type: str = "moshi"
|
||||
lora_weights: Path | None = None
|
||||
lm_gen_config: dict = field(default_factory=dict)
|
||||
|
||||
@staticmethod
|
||||
def from_hf_repo(hf_repo: str,
|
||||
moshi_weights: Path | str | None = None,
|
||||
mimi_weights: Path | str | None = None,
|
||||
tokenizer: Path | str | None = None,
|
||||
config_path: Path | str | None = None) -> 'CheckpointInfo':
|
||||
def from_hf_repo(
|
||||
hf_repo: str,
|
||||
moshi_weights: Path | str | None = None,
|
||||
mimi_weights: Path | str | None = None,
|
||||
tokenizer: Path | str | None = None,
|
||||
config_path: Path | str | None = None,
|
||||
lora_weights: Path | str | None = None,
|
||||
) -> "CheckpointInfo":
|
||||
"""Downloads the checkpoints from the given repo, along with its config.
|
||||
|
||||
Extra overrides are possible for each of Moshi, Mimi, or the text tokenizer,
|
||||
|
@ -163,28 +168,32 @@ class CheckpointInfo:
|
|||
"""
|
||||
if config_path is None:
|
||||
try:
|
||||
config_path = hf_hub_download(hf_repo, 'config.json')
|
||||
config_path = hf_hub_download(hf_repo, "config.json")
|
||||
except EntryNotFoundError:
|
||||
# No config.json, which might indicate legacy repository.
|
||||
warnings.warn(f"Repository {hf_repo} contains no config.json. "
|
||||
"Assuming this is a Moshi 7B. Support for such repository "
|
||||
"might be removed in the future.")
|
||||
warnings.warn(
|
||||
f"Repository {hf_repo} contains no config.json. "
|
||||
"Assuming this is a Moshi 7B. Support for such repository "
|
||||
"might be removed in the future."
|
||||
)
|
||||
if config_path is None:
|
||||
moshi_name = MOSHI_NAME
|
||||
mimi_name = MIMI_NAME
|
||||
tokenizer_name = TEXT_TOKENIZER_NAME
|
||||
lm_config = None
|
||||
raw_config = None
|
||||
model_type = 'moshi'
|
||||
model_type = "moshi"
|
||||
lm_gen_config = {}
|
||||
lora_name = None
|
||||
else:
|
||||
raw_config = json.loads(Path(config_path).read_text())
|
||||
lm_config = dict(raw_config)
|
||||
moshi_name = lm_config.pop('moshi_name', MOSHI_NAME)
|
||||
mimi_name = lm_config.pop('mimi_name', MIMI_NAME)
|
||||
tokenizer_name = lm_config.pop('tokenizer_name', TEXT_TOKENIZER_NAME)
|
||||
model_type = lm_config.pop('model_type', 'moshi')
|
||||
lm_gen_config = lm_config.pop('lm_gen_config', {})
|
||||
moshi_name = lm_config.pop("moshi_name", MOSHI_NAME)
|
||||
mimi_name = lm_config.pop("mimi_name", MIMI_NAME)
|
||||
tokenizer_name = lm_config.pop("tokenizer_name", TEXT_TOKENIZER_NAME)
|
||||
lora_name = lm_config.pop("lora_name", None)
|
||||
model_type = lm_config.pop("model_type", "moshi")
|
||||
lm_gen_config = lm_config.pop("lm_gen_config", {})
|
||||
|
||||
if moshi_weights is None:
|
||||
moshi_weights_final = hf_get(moshi_name, hf_repo)
|
||||
|
@ -201,23 +210,47 @@ class CheckpointInfo:
|
|||
else:
|
||||
tokenizer_final = hf_get(tokenizer)
|
||||
|
||||
return CheckpointInfo(
|
||||
moshi_weights_final, mimi_weights_final, tokenizer_final,
|
||||
lm_config, raw_config, model_type, lm_gen_config)
|
||||
if lora_weights is None and lora_name:
|
||||
lora_weights_final = hf_get(lora_name, hf_repo)
|
||||
elif lora_weights is not None:
|
||||
lora_weights_final = hf_get(lora_weights)
|
||||
else:
|
||||
lora_weights_final = None
|
||||
|
||||
def get_mimi(self, device: torch.device | str = 'cpu') -> MimiModel:
|
||||
return CheckpointInfo(
|
||||
moshi_weights_final,
|
||||
mimi_weights_final,
|
||||
tokenizer_final,
|
||||
lm_config,
|
||||
raw_config,
|
||||
model_type,
|
||||
lora_weights_final,
|
||||
lm_gen_config=lm_gen_config,
|
||||
)
|
||||
|
||||
def get_mimi(self, device: torch.device | str = "cpu") -> MimiModel:
|
||||
if self.lm_config is None:
|
||||
num_codebooks = 8
|
||||
else:
|
||||
num_codebooks = self.lm_config['dep_q']
|
||||
num_codebooks = self.lm_config["dep_q"]
|
||||
return get_mimi(self.mimi_weights, num_codebooks=num_codebooks, device=device)
|
||||
|
||||
def get_moshi(self, strict: bool = True, device: torch.device | str = 'cpu',
|
||||
dtype: torch.dtype = torch.bfloat16) -> LMModel:
|
||||
def get_moshi(
|
||||
self,
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
load_weight: bool = True,
|
||||
**kwargs,
|
||||
) -> LMModel:
|
||||
model = get_moshi_lm(
|
||||
self.moshi_weights, lm_kwargs=self.lm_config,
|
||||
device=device, dtype=dtype, strict=strict)
|
||||
if self.model_type == 'hibiki':
|
||||
self.moshi_weights if load_weight else None,
|
||||
lm_kwargs=self.lm_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
lora_weights=self.lora_weights,
|
||||
**kwargs,
|
||||
)
|
||||
if self.model_type == "hibiki":
|
||||
# Sometimes the model samples the EOS (2) too early, which we want to ignore.
|
||||
# We keep generating if the input file is not finished, and this is a way
|
||||
# to implicitly replace early EOS with PAD.
|
||||
|
@ -232,9 +265,9 @@ def _is_safetensors(path: Path | str) -> bool:
|
|||
return Path(path).suffix in (".safetensors", ".sft", ".sfts")
|
||||
|
||||
|
||||
def get_mimi(filename: str | Path,
|
||||
device: torch.device | str = 'cpu',
|
||||
num_codebooks: int = 8) -> MimiModel:
|
||||
def get_mimi(
|
||||
filename: str | Path, device: torch.device | str = "cpu", num_codebooks: int = 8
|
||||
) -> MimiModel:
|
||||
"""Return a pretrained Mimi model."""
|
||||
encoder = SEANetEncoder(**_seanet_kwargs)
|
||||
decoder = SEANetDecoder(**_seanet_kwargs)
|
||||
|
@ -262,7 +295,7 @@ def get_mimi(filename: str | Path,
|
|||
).to(device=device)
|
||||
model.eval()
|
||||
if _is_safetensors(filename):
|
||||
load_model(model, filename)
|
||||
load_model(model, filename, device=str(device))
|
||||
else:
|
||||
pkg = torch.load(filename, "cpu")
|
||||
model.load_state_dict(pkg["model"])
|
||||
|
@ -270,46 +303,99 @@ def get_mimi(filename: str | Path,
|
|||
return model
|
||||
|
||||
|
||||
def get_moshi_lm(filename: str | Path,
|
||||
lm_kwargs: tp.Optional[tp.Dict] = None,
|
||||
device: torch.device | str = 'cpu',
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
strict: bool = True) -> LMModel:
|
||||
def get_moshi_lm(
|
||||
filename: str | Path | None,
|
||||
lm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
lora_weights: str | Path | None = None,
|
||||
fuse_lora: bool = False,
|
||||
lm_kwargs_overrides={},
|
||||
) -> LMModel:
|
||||
if lm_kwargs is None:
|
||||
lm_kwargs = _lm_kwargs
|
||||
lm_kwargs = dict(lm_kwargs)
|
||||
assert lm_kwargs is not None
|
||||
|
||||
if "conditioners" in lm_kwargs:
|
||||
lm_kwargs["condition_provider"] = get_conditioner_provider(lm_kwargs["dim"], device, lm_kwargs)
|
||||
lm_kwargs["condition_provider"] = get_conditioner_provider(
|
||||
lm_kwargs["dim"], device, lm_kwargs
|
||||
)
|
||||
del lm_kwargs["conditioners"]
|
||||
if "fuser" in lm_kwargs:
|
||||
lm_kwargs["fuser"] = get_condition_fuser(lm_kwargs)
|
||||
|
||||
lm_kwargs = lm_kwargs | lm_kwargs_overrides
|
||||
assert lm_kwargs is not None
|
||||
|
||||
# deprecated params.
|
||||
lm_kwargs.pop("depformer_causal", None)
|
||||
|
||||
# lora params.
|
||||
lora = lm_kwargs.pop("lora", False)
|
||||
lora_rank = lm_kwargs.pop("lora_rank", 128)
|
||||
lora_scaling = lm_kwargs.pop("lora_scaling", 2.0)
|
||||
|
||||
init_device = device
|
||||
if filename is not None:
|
||||
init_device = torch.device('meta')
|
||||
|
||||
model = LMModel(
|
||||
device=device,
|
||||
device=init_device,
|
||||
dtype=dtype,
|
||||
**lm_kwargs)
|
||||
model.eval()
|
||||
if _is_safetensors(filename):
|
||||
load_model(model, filename, strict=strict)
|
||||
else:
|
||||
pkg = torch.load(
|
||||
filename,
|
||||
"cpu",
|
||||
|
||||
if filename is not None:
|
||||
if _is_safetensors(filename):
|
||||
state = load_file(filename, device=str(device))
|
||||
for key, value in state.items():
|
||||
if value.dtype.is_floating_point:
|
||||
value = value.to(dtype=dtype)
|
||||
state[key] = value
|
||||
model.load_state_dict(state, assign=True)
|
||||
|
||||
else:
|
||||
pkg = torch.load(filename, "cpu",)
|
||||
model.load_state_dict(pkg["fsdp_best_state"]["model"], assign=True)
|
||||
|
||||
if lora:
|
||||
assert not lm_kwargs.get("quantize"), (
|
||||
"LoRA and quantization are incompatible for now."
|
||||
)
|
||||
model.load_state_dict(pkg["fsdp_best_state"]["model"])
|
||||
model = get_lora_moshi(
|
||||
model=model,
|
||||
lora_rank=lora_rank,
|
||||
lora_scaling=lora_scaling,
|
||||
lora_weights=lora_weights,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
fuse_lora=fuse_lora,
|
||||
)
|
||||
else:
|
||||
assert lora_weights is None, (
|
||||
"`lora` is False, but received some lora_weights to load."
|
||||
)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioner(output_dim: int, device: torch.device | str, conditioner_cfg: dict) -> BaseConditioner:
|
||||
def get_conditioner(
|
||||
output_dim: int, device: torch.device | str, conditioner_cfg: dict
|
||||
) -> BaseConditioner:
|
||||
conditioner_type = conditioner_cfg["type"]
|
||||
conditioner_kwargs = conditioner_cfg[conditioner_type]
|
||||
conditioner_kwargs.update({'output_dim': output_dim, 'device': device})
|
||||
if conditioner_type == 'lut':
|
||||
conditioner_kwargs.update({"output_dim": output_dim, "device": device})
|
||||
if conditioner_type == "lut":
|
||||
from ..conditioners.text import LUTConditioner
|
||||
|
||||
return LUTConditioner(**conditioner_kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknow conditioner type {conditioner_type}.")
|
||||
|
||||
|
||||
def get_conditioner_provider(output_dim: int, device: torch.device | str, cfg: dict) -> ConditionProvider:
|
||||
def get_conditioner_provider(
|
||||
output_dim: int, device: torch.device | str, cfg: dict
|
||||
) -> ConditionProvider:
|
||||
"""Instantiate a conditioning model."""
|
||||
conditioners: tp.Dict[str, BaseConditioner] = {}
|
||||
for cond, cond_cfg in cfg["conditioners"].items():
|
||||
|
@ -321,8 +407,39 @@ def get_conditioner_provider(output_dim: int, device: torch.device | str, cfg: d
|
|||
def get_condition_fuser(cfg: dict) -> ConditionFuser:
|
||||
"""Instantiate a condition fuser object."""
|
||||
fuser_cfg = cfg["fuser"]
|
||||
fuser_methods = ['sum', 'cross', 'prepend']
|
||||
fuser_methods = ["sum", "cross", "prepend"]
|
||||
fuse2cond = {k: fuser_cfg.get(k, []) for k in fuser_methods}
|
||||
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
||||
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
||||
return fuser
|
||||
|
||||
|
||||
def get_lora_moshi(
|
||||
model: LMModel,
|
||||
lora_weights: str | Path | None,
|
||||
lora_rank: int,
|
||||
lora_scaling: float,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: torch.device | str = "cpu",
|
||||
fuse_lora: bool = True,
|
||||
) -> LMModel:
|
||||
init_device = device
|
||||
if lora_weights is not None:
|
||||
init_device = torch.device('meta')
|
||||
replace_all_linear_with_lora(model, lora_rank, lora_scaling, device=init_device)
|
||||
if lora_weights is not None:
|
||||
assert _is_safetensors(lora_weights), "LoRA weights must be a safetensors file."
|
||||
lora_state_dict = load_file(lora_weights, device=str(device))
|
||||
for key, value in lora_state_dict.items():
|
||||
if value.dtype.is_floating_point:
|
||||
value = value.to(dtype=dtype)
|
||||
lora_state_dict[key] = value
|
||||
res = model.load_state_dict(lora_state_dict, strict=False, assign=True)
|
||||
if res.unexpected_keys:
|
||||
raise RuntimeError(
|
||||
f"unexpected_keys in the lora weights: {res.unexpected_keys}"
|
||||
)
|
||||
model = model.to(dtype=dtype, device=device)
|
||||
if fuse_lora:
|
||||
replace_lora_with_linear(model)
|
||||
return model
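End to end, the new loading path looks roughly like the sketch below (not part of the diff); the repo is simply `DEFAULT_REPO`, `fuse_lora` only matters when the checkpoint config enables LoRA, and the call assumes access to the Hugging Face hub:

info = CheckpointInfo.from_hf_repo("kyutai/moshiko-pytorch-bf16", lora_weights=None)
mimi = info.get_mimi(device="cuda")
lm = info.get_moshi(device="cuda", dtype=torch.bfloat16, fuse_lora=True)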
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from contextlib import ExitStack
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..utils.compile import torch_compile_lazy
|
||||
from ..utils import quantize
|
||||
from ..utils.compile import torch_compile_lazy, no_compile
|
||||
|
||||
|
||||
@torch_compile_lazy
|
||||
|
@ -22,6 +22,20 @@ def gating_forward_kernel(
|
|||
return x
|
||||
|
||||
|
||||
def gating_forward_generic(
|
||||
linear_in: nn.Module,
|
||||
linear_out: nn.Module,
|
||||
activation,
|
||||
x: torch.Tensor
|
||||
):
|
||||
x = linear_in(x)
|
||||
B, T, _ = x.shape
|
||||
x = x.view(B, T, 2, -1)
|
||||
x = activation(x[..., 0, :]) * x[..., 1, :]
|
||||
x = linear_out(x)
|
||||
return x
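A shape-level sketch (not part of the diff) of what `gating_forward_generic` computes: a GLU-style feed-forward where `linear_in` doubles the hidden size and one half gates the other through the activation. The dimensions below are made up:

lin_in = nn.Linear(8, 2 * 16, bias=False)        # dim=8, hidden=16
lin_out = nn.Linear(16, 8, bias=False)
y = gating_forward_generic(lin_in, lin_out, F.silu, torch.randn(2, 5, 8))
assert y.shape == (2, 5, 8)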
|
||||
|
||||
|
||||
class ActivationGating(nn.Module):
|
||||
"""
|
||||
Gating FFN layer, using the given activation.
|
||||
|
@ -42,23 +56,30 @@ class ActivationGating(nn.Module):
|
|||
hidden = (21 * dim) // 8
|
||||
else:
|
||||
hidden = (2 * dim_feedforward) // 3
|
||||
|
||||
self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
|
||||
self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
|
||||
|
||||
# We try to follow the default PyTorch MHA convention, to easily compare results.
|
||||
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if quantize.is_quantized(self.linear_in):
|
||||
assert quantize.is_quantized(self.linear_out)
|
||||
x = quantize.linear(self.linear_in, x)
|
||||
B, T, _ = x.shape
|
||||
x = x.view(B, T, 2, -1)
|
||||
x = self.activation(x[..., 0, :]) * x[..., 1, :]
|
||||
x = quantize.linear(self.linear_out, x)
|
||||
return x
|
||||
|
||||
return gating_forward_kernel(
|
||||
self.linear_in.weight, self.linear_out.weight, self.activation, x
|
||||
)
|
||||
if isinstance(self.linear_in, nn.Linear):
|
||||
assert isinstance(self.linear_out, nn.Linear)
|
||||
with ExitStack() as stack:
|
||||
if self.training:
|
||||
stack.enter_context(no_compile())
|
||||
return gating_forward_kernel(
|
||||
self.linear_in.weight, self.linear_out.weight, self.activation, x
|
||||
)
|
||||
else:
|
||||
return gating_forward_generic(
|
||||
self.linear_in,
|
||||
self.linear_out,
|
||||
self.activation,
|
||||
x
|
||||
)
|
||||
|
||||
|
||||
def _get_activation(name: str):
|
||||
|
@ -73,7 +94,8 @@ def _get_activation(name: str):
|
|||
|
||||
|
||||
def _make_gating(
|
||||
name: str, dim: int, dim_feedforward: int, **factory_kwargs
|
||||
name: str, dim: int, dim_feedforward: int,
|
||||
**factory_kwargs
|
||||
) -> nn.Module:
|
||||
return ActivationGating(
|
||||
dim, dim_feedforward, _get_activation(name), **factory_kwargs
|
||||
|
@ -84,9 +106,10 @@ def make_gating(
|
|||
name: str, dim: int, dim_feedforward: int, **factory_kwargs
|
||||
) -> nn.Module:
|
||||
gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
|
||||
max_params = 2 * dim * dim_feedforward
|
||||
params = sum(p.numel() for p in gating.parameters())
|
||||
assert (
|
||||
params <= max_params
|
||||
), f"{name} gating has {params} params, max is {max_params}"
|
||||
if isinstance(gating.linear_in, nn.Linear):
|
||||
max_params = 2 * dim * dim_feedforward
|
||||
params = sum(p.numel() for p in gating.parameters())
|
||||
assert (
|
||||
params <= max_params
|
||||
), f"{name} gating has {params} params, max is {max_params}"
|
||||
return gating
|
||||
|
|
moshi/moshi/modules/lora.py (new file, 122 lines)
|
@ -0,0 +1,122 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None):
|
||||
""" Recursively replace all Linear layers with LoRALinear layers."""
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
if device is None:
|
||||
this_device = child.weight.device
|
||||
else:
|
||||
this_device = device
|
||||
if dtype is None:
|
||||
this_dtype = child.weight.dtype
|
||||
else:
|
||||
this_dtype = dtype
|
||||
lora = LoRALinear(child.in_features, child.out_features,
|
||||
rank, scaling, device=this_device, dtype=this_dtype)
|
||||
lora.frozen_W = child
|
||||
setattr(module, name, lora)
|
||||
else:
|
||||
replace_all_linear_with_lora(child, rank, scaling, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def replace_lora_with_linear(module):
|
||||
"""Recursively replace all LoRALinear layers with Linear layers."""
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, LoRALinear):
|
||||
# Compute merged weights: W' = W + scaling * B @ A
|
||||
merged_weight = child.frozen_W.weight.data + \
|
||||
child.scaling * (child.lora_B.weight @ child.lora_A.weight)
|
||||
# Create a standard Linear layer with the same in/out features
|
||||
new_linear = nn.Linear(child.frozen_W.in_features,
|
||||
child.frozen_W.out_features, bias=False,
|
||||
device=torch.device('meta'),
|
||||
dtype=merged_weight.dtype)
|
||||
new_linear.weight = nn.Parameter(
|
||||
merged_weight, requires_grad=merged_weight.requires_grad) # Transfer merged weights
|
||||
setattr(module, name, new_linear) # Replace the module
|
||||
else:
|
||||
replace_lora_with_linear(child) # Recursively process submodules
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""
|
||||
Implementation of:
|
||||
- LoRA: https://arxiv.org/abs/2106.09685
|
||||
|
||||
Notes:
|
||||
- Freezing is handled at the network level, not the layer level.
|
||||
- Scaling factor controls relative importance of LoRA skip
|
||||
connection versus original frozen weight. General guidance is
|
||||
to keep it to 2.0 and sweep over learning rate when changing
|
||||
the rank.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int,
|
||||
scaling: float,
|
||||
bias: bool = False,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
assert not bias
|
||||
self.bias = bias
|
||||
self.rank = rank
|
||||
self.scaling = scaling
|
||||
|
||||
self.lora_A = nn.Linear(
|
||||
self.in_features,
|
||||
self.rank,
|
||||
bias=self.bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.lora_B = nn.Linear(
|
||||
self.rank,
|
||||
self.out_features,
|
||||
bias=self.bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.frozen_W = nn.Linear(self.in_features,
|
||||
self.out_features,
|
||||
bias=self.bias,
|
||||
device=device,
|
||||
dtype=dtype)
|
||||
|
||||
self._register_load_state_dict_pre_hook(LoRALinear._load_hook, with_module=True)
|
||||
|
||||
def merge_weight(self):
|
||||
with torch.no_grad():
|
||||
down_weight = self.lora_A.weight
|
||||
up_weight = self.lora_B.weight
|
||||
|
||||
weight = up_weight.mm(down_weight) * self.scaling
|
||||
|
||||
weight += self.frozen_W.weight
|
||||
return weight
|
||||
|
||||
@staticmethod
|
||||
def _load_hook(module, state_dict, prefix, *_):
|
||||
key_name = prefix + "weight"
|
||||
if key_name in state_dict:
|
||||
w_ref = state_dict.pop(key_name)
|
||||
state_dict[prefix + 'frozen_W.weight'] = w_ref
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
lora = self.lora_B(self.lora_A(x))
|
||||
return self.frozen_W(x) + lora * self.scaling
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}Linear(in_features={}, out_features={}, r={})".format(
|
||||
"LoRA", self.in_features, self.out_features, self.rank)
|
|
@ -12,25 +12,18 @@ See `StreamingTransformer` for more information.
|
|||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..utils.compile import no_compile
|
||||
from ..utils import quantize
|
||||
from ..utils.quantize import replace_linear_with_qlinear
|
||||
from .gating import make_gating
|
||||
from .rope import RotaryEmbedding
|
||||
from .streaming import StreamingModule, StreamingContainer, State
|
||||
|
||||
|
||||
def quantize_transformer(module: torch.nn.Module):
|
||||
for name, child in module.named_modules():
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
quantize.quantize_linear(child)
|
||||
elif isinstance(child, StreamingMultiheadAttention):
|
||||
quantize.quantize_param(child, 'in_proj_weight')
|
||||
from .lora import LoRALinear
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
|
||||
|
||||
class LayerNormF32(nn.LayerNorm):
|
||||
|
@ -260,6 +253,34 @@ class RingKVCache:
|
|||
return KVCacheResult(keys, values, positions)
|
||||
|
||||
|
||||
def apply_weights_per_step(modules: nn.ModuleList, schedule: list[int] | None,
|
||||
x: torch.Tensor, offset: int) -> torch.Tensor:
|
||||
"""Utility to apply a multi linear layer to the given input. A multi linear layer
|
||||
applies a different set of weights for each time step.
|
||||
|
||||
Args:
|
||||
modules (nn.ModuleList): apply weights per step.
|
||||
schedule (list[int] or None): schedule for weight sharing.
|
||||
x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
|
||||
offset (int): offset for the current time step, in particular for decoding, with
|
||||
time steps provided one by one.
|
||||
"""
|
||||
|
||||
if len(modules) == 1:
|
||||
return modules[0](x)
|
||||
|
||||
ys: list[torch.Tensor] = []
|
||||
B, T, C = x.shape
|
||||
for t in range(T):
|
||||
module_index = t + offset
|
||||
if schedule is not None:
|
||||
module_index = schedule[module_index]
|
||||
y = modules[module_index](x[:, t: t + 1])
|
||||
ys.append(y)
|
||||
out = torch.cat(ys, 1)
|
||||
return out
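A sketch (not part of the diff) of the weight-sharing behaviour; the module sizes and schedule are made up:

mods = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(3)])
x = torch.randn(2, 4, 4)                                     # [B, T, C] with T=4 decoded steps
y = apply_weights_per_step(mods, [0, 0, 1, 2], x, offset=0)  # steps 0 and 1 share the first linear
assert y.shape == x.shape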
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MHAState(State):
|
||||
kv_cache: RingKVCache
|
||||
|
@ -316,7 +337,6 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
|||
self.weights_per_step = weights_per_step
|
||||
self.weights_per_step_schedule = weights_per_step_schedule
|
||||
|
||||
out_dim = embed_dim
|
||||
out_dim = 3 * embed_dim
|
||||
mult = 1
|
||||
if weights_per_step:
|
||||
|
@ -325,13 +345,49 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
|||
mult = max(weights_per_step_schedule) + 1
|
||||
else:
|
||||
mult = weights_per_step
|
||||
in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
|
||||
# We try to follow the default PyTorch MHA convention, to easily compare results.
|
||||
self.in_proj_weight = in_proj.weight
|
||||
self.in_proj_bias = in_proj.bias
|
||||
self.out_proj = nn.Linear(
|
||||
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
|
||||
self.mult = mult
|
||||
|
||||
# Split in one linear per step
|
||||
self.out_projs = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
|
||||
for _ in range(mult)
|
||||
]
|
||||
)
|
||||
self.in_projs = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(embed_dim, out_dim, bias=False, **factory_kwargs)
|
||||
for _ in range(mult)
|
||||
]
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(StreamingMultiheadAttention._load_hook, with_module=True)
|
||||
|
||||
@staticmethod
|
||||
def _load_hook(module, state_dict, prefix, *_):
|
||||
mappings = {
|
||||
'in_proj_weight': 'in_projs.{i}.weight',
|
||||
'in_proj.weight': 'in_projs.{i}.weight',
|
||||
'in_proj.lora_A.weight': 'in_projs.{i}.lora_A.weight',
|
||||
'in_proj.lora_B.weight': 'in_projs.{i}.lora_B.weight',
|
||||
'out_proj.weight': 'out_projs.{i}.weight',
|
||||
'out_proj.lora_A.weight': 'out_projs.{i}.lora_A.weight',
|
||||
'out_proj.lora_B.weight': 'out_projs.{i}.lora_B.weight',
|
||||
}
|
||||
|
||||
mult = module.mult
|
||||
# _scb suffix is for quantized data.
|
||||
for suffix in ['', '_scb']:
|
||||
for source, target in mappings.items():
|
||||
this_source = prefix + source + suffix
|
||||
if this_source in state_dict:
|
||||
weight = state_dict[this_source]
|
||||
_, *OD = weight.shape
|
||||
weight = weight.view(mult, -1, *OD)
|
||||
for i in range(mult):
|
||||
this_target = prefix + target.format(i=i) + suffix
|
||||
state_dict[this_target] = weight[i]
|
||||
state_dict.pop(this_source)
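For reference, a shape-only sketch (not part of the diff) of what this hook does to a legacy fused projection; `mult` and `embed_dim` are hypothetical:

mult, embed_dim = 2, 8
fused = torch.randn(mult * 3 * embed_dim, embed_dim)   # legacy 'in_proj_weight' tensor
split = fused.view(mult, -1, embed_dim)                # one [3*embed_dim, embed_dim] slice per in_projs[i]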
|
||||
|
||||
def _init_streaming_state(self, batch_size: int) -> _MHAState:
|
||||
if self.context is None:
|
||||
|
@ -343,13 +399,20 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
|||
)
|
||||
else:
|
||||
capacity = self.context
|
||||
device = self.in_proj_weight.device
|
||||
# TODO: the following estimation will not work great with FSDP.
|
||||
if quantize.is_quantized(self, 'in_proj_weight'):
|
||||
# We are running with quantization
|
||||
|
||||
in_proj = self.in_projs[0]
|
||||
if isinstance(in_proj, LoRALinear):
|
||||
device = in_proj.lora_A.weight.device
|
||||
dtype = in_proj.lora_A.weight.dtype
|
||||
elif isinstance(in_proj, nn.Linear):
|
||||
device = in_proj.weight.device
|
||||
dtype = in_proj.weight.dtype
|
||||
elif isinstance(in_proj, quantize.QLinear):
|
||||
device = in_proj.weight.device
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = self.in_proj_weight.dtype
|
||||
raise RuntimeError(f"Unknown type {type(in_proj)} for linear.")
|
||||
|
||||
dim_per_head = self.embed_dim // self.num_heads
|
||||
kv_cache = RingKVCache(
|
||||
batch_size, self.num_heads, dim_per_head, capacity, device, dtype
|
||||
|
@ -379,16 +442,12 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
|||
offset = state.offset
|
||||
offset_cpu = state.offset_cpu
|
||||
|
||||
if self.weights_per_step:
|
||||
projected = quantize.multi_linear(
|
||||
self.weights_per_step, self.weights_per_step_schedule,
|
||||
self, query, offset_cpu, name='in_proj_weight')
|
||||
else:
|
||||
projected = quantize.linear(self, query, 'in_proj_weight')
|
||||
projected = apply_weights_per_step(
|
||||
self.in_projs, self.weights_per_step_schedule, query, offset_cpu)
|
||||
|
||||
q, k, v = rearrange(
|
||||
projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
|
||||
)
|
||||
|
||||
if self.rope:
|
||||
q, k = self.rope(q, k, offset, time_before_heads=False)
|
||||
|
||||
|
@ -407,12 +466,9 @@ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
|||
x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
|
||||
|
||||
x = rearrange(x, "b h t d -> b t (h d)")
|
||||
if self.weights_per_step:
|
||||
x = quantize.multi_linear(
|
||||
self.weights_per_step, self.weights_per_step_schedule,
|
||||
self.out_proj, x, offset_cpu)
|
||||
else:
|
||||
x = quantize.linear(self.out_proj, x)
|
||||
x = apply_weights_per_step(
|
||||
self.out_projs, self.weights_per_step_schedule, x, offset_cpu)
|
||||
|
||||
if state is not None:
|
||||
state.offset.add_(T)
|
||||
state.offset_cpu += T
|
||||
|
@ -565,19 +621,11 @@ class StreamingTransformerLayer(StreamingModule[_LayerState]):
|
|||
if self.gating is None:
|
||||
assert self.linear1 is not None
|
||||
assert self.linear2 is not None
|
||||
update = quantize.linear(self.linear2, self.activation(quantize.linear(self.linear1, x)))
|
||||
update = self.linear2(self.activation(self.linear1(x)))
|
||||
else:
|
||||
if self.weights_per_step:
|
||||
assert isinstance(self.gating, nn.ModuleList)
|
||||
B, T, D = x.shape
|
||||
ys = []
|
||||
for t in range(T):
|
||||
linear_index = offset + t
|
||||
if self.weights_per_step_schedule:
|
||||
linear_index = self.weights_per_step_schedule[linear_index]
|
||||
y = self.gating[linear_index](x[:, t:t + 1])
|
||||
ys.append(y)
|
||||
update = torch.cat(ys, dim=1)
|
||||
update = apply_weights_per_step(self.gating, self.weights_per_step_schedule, x, offset)
|
||||
else:
|
||||
update = self.gating(x)
|
||||
return x_orig.to(update) + self.layer_scale_2(update)
|
||||
|
@ -645,6 +693,7 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
|
|||
betas: tp.Optional[tp.Tuple[float, float]] = None,
|
||||
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
|
||||
quantize: bool = False,
|
||||
checkpointing: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
|
@ -662,6 +711,8 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
|
|||
if self.positional_embedding in {"rope", "sin_rope"}:
|
||||
self.rope = RotaryEmbedding(max_period=max_period)
|
||||
|
||||
self.checkpointing = checkpointing
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.layers.append(
|
||||
|
@ -680,7 +731,7 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
|
|||
if quantize:
|
||||
# Quantizing layers one by one to avoid taking too much space during init.
|
||||
self.layers[-1].to(device=device, dtype=dtype)
|
||||
quantize_transformer(self.layers[-1])
|
||||
replace_linear_with_qlinear(self.layers[-1])
|
||||
|
||||
def _init_streaming_state(self, batch_size: int) -> _TransformerState:
|
||||
device = next(self.parameters()).device
|
||||
|
@ -705,7 +756,16 @@ class StreamingTransformer(StreamingModule[_TransformerState]):
|
|||
x = x + self.positional_scale * pos_emb
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, *args, **kwargs)
|
||||
if self.checkpointing:
|
||||
y = torch_checkpoint(
|
||||
layer, x, *args, use_reentrant=False,
|
||||
determinism_check='none',
|
||||
preserve_rng_state=False,
|
||||
**kwargs)
|
||||
assert isinstance(y, torch.Tensor)
|
||||
x = y
|
||||
else:
|
||||
x = layer(x, *args, **kwargs)
|
||||
|
||||
if state is not None:
|
||||
state.offset.add_(T)
|
||||
|
|
|
@ -13,7 +13,6 @@ import tarfile
|
|||
import time
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
@ -21,8 +20,6 @@ import numpy as np
|
|||
import sentencepiece
|
||||
import sphn
|
||||
import torch
|
||||
|
||||
|
||||
from .client_utils import log
|
||||
from .models import loaders, MimiModel, LMModel, LMGen
|
||||
from .run_inference import get_condition_tensors
|
||||
|
@ -71,6 +68,7 @@ class ServerState:
|
|||
if tokens is None:
|
||||
continue
|
||||
_ = self.mimi.decode(tokens[:, 1:])
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
async def handle_chat(self, request):
|
||||
|
@ -190,8 +188,12 @@ def main():
|
|||
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
|
||||
help="HF repo to look into, defaults Moshiko. "
|
||||
"Use this to select a different pre-trained model.")
|
||||
parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
|
||||
parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
|
||||
parser.add_argument("--cfg-coef", type=float, default=1., help="CFG coefficient.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
|
||||
parser.add_argument("--no_fuse_lora", action="store_false", dest="fuse_lora", default=True,
|
||||
help="Do not fuse LoRA layers intot Linear layers.")
|
||||
parser.add_argument("--half", action="store_const", const=torch.float16, default=torch.bfloat16,
|
||||
dest="dtype", help="Run inference with float16, not bfloat16, better for old GPUs.")
|
||||
parser.add_argument(
|
||||
|
@ -223,7 +225,8 @@ def main():
|
|||
|
||||
log("info", "retrieving checkpoint")
|
||||
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
|
||||
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer)
|
||||
args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
|
||||
lora_weights=args.lora_weight, config_path=args.config_path)
|
||||
log("info", "loading mimi")
|
||||
mimi = checkpoint_info.get_mimi(device=args.device)
|
||||
log("info", "mimi loaded")
|
||||
|
@ -231,7 +234,7 @@ def main():
|
|||
text_tokenizer = checkpoint_info.get_text_tokenizer()
|
||||
|
||||
log("info", "loading moshi")
|
||||
lm = checkpoint_info.get_moshi(device=args.device, dtype=args.dtype)
|
||||
lm = checkpoint_info.get_moshi(device=args.device, dtype=args.dtype, fuse_lora=args.fuse_lora)
|
||||
log("info", "moshi loaded")
|
||||
|
||||
state = ServerState(checkpoint_info.model_type, mimi, text_tokenizer, lm, args.cfg_coef, args.device,
|
||||
|
|
|
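For reference, the new loading path used above (and by the import scripts later in this commit) can be exercised directly; a minimal sketch, with placeholder paths, based only on the arguments visible in the diff:

    import torch
    from moshi.models import loaders

    # Placeholder paths; these correspond to --lora-weight / --config-path above.
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        loaders.DEFAULT_REPO,
        lora_weights="/path/to/lora.safetensors",
        config_path="/path/to/config.json",
    )
    mimi = checkpoint_info.get_mimi(device="cuda")
    # fuse_lora=True folds the LoRA layers into the base Linear layers at load time,
    # matching the server's default behaviour (--no_fuse_lora disables it).
    lm = checkpoint_info.get_moshi(device="cuda", dtype=torch.bfloat16, fuse_lora=True)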
@@ -15,98 +15,48 @@ import torch
from torch import nn


def linear(module: nn.Module, x: torch.Tensor, name='weight') -> torch.Tensor:
    import bitsandbytes as bnb  # type: ignore
    if is_quantized(module, name):
class QLinear(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        from bitsandbytes import functional as bnbF  # type: ignore
        weight = linear.weight
        assert weight.data.dtype.is_floating_point
        assert linear.bias is None
        CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
        self.weight = nn.Parameter(CB, requires_grad=False)
        self.weight_scb = nn.Parameter(SCB, requires_grad=False)

    def forward(self, x):
        import bitsandbytes as bnb  # type: ignore
        state = bnb.MatmulLtState()
        state.CB = getattr(module, name)
        state.CB = self.weight  # type: ignore
        assert isinstance(state.CB, torch.Tensor)
        state.SCB = getattr(module, name + '_scb')
        state.SCB = self.weight_scb  # type: ignore
        assert isinstance(state.SCB, torch.Tensor)
        if state.SCB.dtype != torch.float:
            raise RuntimeError(
                "Expected `weight_scb` to have type float, but got bfloat16. "
                "When using quantized models, care should be taken not to change the dtype of "
                "the model once initialized.")
        assert state.SCB.dtype == torch.float, state.SCB.dtype
        state.has_fp16_weights = False
        y = bnb.matmul(x.half(), state.CB, state=state)
        assert isinstance(y, torch.Tensor)
        return y
    else:
        return nn.functional.linear(x, getattr(module, name))


def multi_linear(num_steps: int, schedule: list[int] | None,
                 module: nn.Module, x: torch.Tensor, offset: int, name='weight') -> torch.Tensor:
    """Utility to apply a multi linear layer to the given input. A multi linear layer
    applies a different set of weights for each time step.

    Args:
        num_steps (int): Number of possible time steps.
        schedule (list[int] or None): schedule for weight sharing.
        weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
        x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
        offset (int): offset for the current time step, in particular for decoding, with
            time steps provided one by one.
    """
    import bitsandbytes as bnb  # type: ignore
    B, T, C = x.shape
    ys: list[torch.Tensor] = []
    if is_quantized(module, name):
        weight = getattr(module, name)
        weight_scb = getattr(module, name + '_scb')
    else:
        weight = getattr(module, name)
        weight_scb = None
    assert isinstance(weight, torch.Tensor)

    num_linear = num_steps
    if schedule is not None:
        num_linear = max(schedule) + 1

    chout, chin = weight.shape
    weight = weight.view(num_linear, -1, chin)
    if weight_scb is not None:
        assert isinstance(weight, torch.Tensor)
        assert weight_scb.shape == (chout,), (weight_scb, chout)
        weight_scb = weight_scb.view(num_linear, -1)
        assert weight_scb.dtype == torch.float, weight_scb.dtype

    for t in range(T):
        linear_index = t + offset
        if schedule is not None:
            linear_index = schedule[linear_index]
        if weight_scb is None:
            y = nn.functional.linear(x[:, t], weight[linear_index])
def replace_linear_with_qlinear(module):
    """Recursively replace all Linear layers with QLinear layers."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, QLinear(child))
        elif isinstance(child, QLinear):
            # Slight issue with the way we implement things: the scale param
            # might get cast with the rest of the model to bfloat16, although
            # we most likely want to keep it as float. For the LM model we might call this function twice,
            # first layer by layer to avoid too big a memory usage, and second, at the end
            # of the LM init, after all other modules are initialized and properly dtyped.
            # In any case that should happen before loading the state dict to avoid a loss of precision.
            child.float()
        else:
            state = bnb.MatmulLtState()
            CB = weight[linear_index]
            state.CB = CB  # type: ignore
            state.SCB = weight_scb[linear_index]
            state.has_fp16_weights = False
            y = bnb.matmul(x[:, t].half(), CB, state=state)
            assert isinstance(y, torch.Tensor)
        ys.append(y)
    out = torch.stack(ys, 1)
    return out


def is_quantized(module: nn.Module, name: str = 'weight'):
    return hasattr(module, name + '_scb')


def quantize_param(module: nn.Module, name: str = 'weight') -> None:
    from bitsandbytes import functional as bnbF  # type: ignore
    if is_quantized(module, name):
        # Due to model casting, the type of SCB might be wrong, although
        # that would only happen during the init. Let's recast it to float.
        SCB = getattr(module, name + '_scb')
        if SCB.dtype != torch.float:
            setattr(module, name + '_scb', nn.Parameter(SCB.to(torch.float), requires_grad=False))
        return
    weight = getattr(module, name)
    assert weight.data.dtype.is_floating_point
    CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16))  # type: ignore
    setattr(module, name, nn.Parameter(CB, requires_grad=False))
    setattr(module, name + '_scb', nn.Parameter(SCB, requires_grad=False))


def quantize_linear(linear: nn.Module) -> None:
    assert linear.bias is None
    quantize_param(linear)
            replace_linear_with_qlinear(child)
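Usage sketch for the new quantization helper above (requires bitsandbytes and a CUDA device for the int8 matmul; the Linear layers are bias-free to match the assertion in QLinear.__init__):

    import torch
    from torch import nn
    from moshi.utils.quantize import replace_linear_with_qlinear

    # Tiny bias-free model: QLinear asserts that linear.bias is None.
    model = nn.Sequential(
        nn.Linear(16, 32, bias=False),
        nn.GELU(),
        nn.Linear(32, 8, bias=False),
    ).cuda()

    replace_linear_with_qlinear(model)  # every nn.Linear is now an int8 QLinear
    with torch.no_grad():
        y = model(torch.randn(4, 16, device="cuda"))
    print(y.shape, y.dtype)  # activations come out of bnb.matmul in float16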
52  moshi/moshi/utils/utils.py  Normal file
@@ -0,0 +1,52 @@
import torch

from .compile import torch_compile_lazy


@torch_compile_lazy
def cross_entropy(
        logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, dtype=torch.float32,
        logits_soft_clip: float | None = None) -> torch.Tensor:
    """Compute cross entropy between multi-codebook targets and model's logits.
    The cross entropy is computed per codebook to provide codebook-level cross entropy.
    Valid timesteps for each of the codebooks are pulled from the mask, where invalid
    timesteps are set to 0.

    Args:
        logits (torch.Tensor): Model's logits of shape [B, K, T, card].
        targets (torch.Tensor): Target codes, of shape [B, K, T].
        mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        dtype (type): Data type of the output cross entropy.
        logits_soft_clip (float): Clipping value for the logits to avoid numerical instability.
            Recommended value: 30.0.
    Returns:
        ce (torch.Tensor): Cross entropy [B, K, T] with type dtype.
    """
    output_shape = targets.shape
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    logits = logits.view(-1, logits.shape[-1])
    targets = targets.reshape(-1)
    mask = mask.reshape(-1)

    safe_targets = torch.where(
        mask,
        targets,
        torch.zeros(1, device=targets.device, dtype=targets.dtype),
    )

    # Chunking the conversion to float32 to avoid OOMs.
    ce_chunks = []
    for logits_chunk, targets_chunk in zip(torch.chunk(logits, 4), torch.chunk(safe_targets, 4)):
        logits_chunk = logits_chunk.to(dtype)
        if logits_soft_clip is not None:
            logits_chunk = logits_soft_clip * torch.tanh(logits_chunk / logits_soft_clip)
        log_partition = torch.logsumexp(logits_chunk, dim=-1, keepdim=True)

        # For some reason, the PyTorch cross entropy is super slow with inputs with large cardinality
        # (e.g. 32000), so we reimplement the cross entropy ourselves...
        ce_chunks.append(log_partition - logits_chunk.gather(-1, targets_chunk[..., None]))
    ce = torch.cat(ce_chunks, dim=0)
    ce = ce[..., 0]
    ce = torch.where(mask, ce, torch.zeros(1, device=ce.device, dtype=ce.dtype))
    return ce.view(output_shape)
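A small usage sketch for cross_entropy above, using the [B, K, T, card] shapes from its docstring (random tensors, for illustration only):

    import torch
    from moshi.utils.utils import cross_entropy

    B, K, T, card = 2, 8, 10, 32
    logits = torch.randn(B, K, T, card)
    targets = torch.randint(card, (B, K, T))
    mask = torch.rand(B, K, T) > 0.1  # invalid positions contribute zero cross entropy

    ce = cross_entropy(logits, targets, mask, logits_soft_clip=30.0)
    print(ce.shape)  # torch.Size([2, 8, 10])
    per_codebook = ce.sum(dim=(0, 2)) / mask.sum(dim=(0, 2))  # codebook-level CE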
@@ -35,6 +35,7 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
dev = [
    "pyright",
    "pytest",
    "flake8",
    "pre-commit",
    "gradio-webrtc>=0.0.18"
BIN  moshi/tests/assets/test_lm_codes.safetensors  (Normal file, binary file not shown)
BIN  moshi/tests/assets/test_lm_model.safetensors  (Normal file, binary file not shown)
BIN  moshi/tests/assets/test_lm_out.safetensors  (Normal file, binary file not shown)
70  moshi/tests/test_lm.py  Normal file
@@ -0,0 +1,70 @@
from pathlib import Path
from safetensors.torch import load_file, load_model
import torch

from moshi.models import lm
from moshi.utils.utils import cross_entropy


def _get_assets() -> Path:
    return Path(__file__).parent / 'assets'


def _get_lm(device=None, dtype=torch.float32) -> lm.LMModel:
    torch.manual_seed(1234)
    model = lm.LMModel(
        delays=[0, 1, 2, 4],
        n_q=3,
        dep_q=3,
        card=32,
        text_card=48,
        dim=16,
        num_layers=2,
        num_heads=1,
        hidden_scale=1,
        depformer_dim=16,
        depformer_multi_linear=True,
        depformer_weights_per_step=True,
        depformer_weights_per_step_schedule=[0, 1, 1],
        depformer_low_rank_embeddings=8,
        depformer_num_heads=1,
        depformer_gating='silu',
        context=4,
        device=device,
        dtype=dtype,
    )
    return model


def test_init():
    _get_lm(dtype=torch.float32)
    _get_lm(dtype=torch.bfloat16)
    _get_lm(dtype=torch.float16)


@torch.no_grad
def test_forward():
    model = _get_lm()
    load_model(model, _get_assets() / 'test_lm_model.safetensors')
    codes = load_file(_get_assets() / 'test_lm_codes.safetensors')['codes']
    out = model(codes)
    assert out.logits is not None
    assert out.text_logits is not None
    assert out.mask.shape == codes[:, 1:].shape
    assert out.text_mask.shape == codes[:, :1].shape
    assert out.logits.shape[:-1] == codes[:, 1:].shape
    assert out.logits.shape[-1] == model.card
    assert out.text_logits.shape[-1] == model.text_card

    ref_out = load_file(_get_assets() / 'test_lm_out.safetensors')
    assert (ref_out['mask'] == out.mask).all()
    assert (ref_out['text_mask'] == out.text_mask).all()
    ce = cross_entropy(out.logits, codes[:, 1:], out.mask)
    ce_ref = cross_entropy(ref_out['logits'], codes[:, 1:], out.mask)
    delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
    assert delta.amax() <= 1e-6, delta.amax()

    ce = cross_entropy(out.text_logits, codes[:, :1], out.text_mask)
    ce_ref = cross_entropy(ref_out['text_logits'], codes[:, :1], out.text_mask)
    delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
    assert delta.amax() <= 1e-6, delta.amax()
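The three safetensors assets referenced by this test are binary and not shown in the diff. A sketch of how such fixtures could be regenerated with the safetensors API (this is an assumption about the workflow, not part of the commit; the model configuration and seed are copied from _get_lm above):

    import torch
    from safetensors.torch import save_file, save_model

    from moshi.models import lm

    # Same tiny configuration and seed as _get_lm() above, so the fixtures stay reproducible.
    torch.manual_seed(1234)
    model = lm.LMModel(
        delays=[0, 1, 2, 4], n_q=3, dep_q=3, card=32, text_card=48, dim=16,
        num_layers=2, num_heads=1, hidden_scale=1, depformer_dim=16,
        depformer_multi_linear=True, depformer_weights_per_step=True,
        depformer_weights_per_step_schedule=[0, 1, 1],
        depformer_low_rank_embeddings=8, depformer_num_heads=1,
        depformer_gating='silu', context=4,
    )

    # Hypothetical codes tensor: [B, K, T] with K = 1 text stream + 3 audio streams.
    codes = torch.randint(model.card, (1, 4, 16))
    out = model(codes)

    save_model(model, "test_lm_model.safetensors")
    save_file({"codes": codes}, "test_lm_codes.safetensors")
    save_file({"logits": out.logits, "text_logits": out.text_logits,
               "mask": out.mask, "text_mask": out.text_mask},
              "test_lm_out.safetensors")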
@@ -13,7 +13,7 @@ from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import save_file

from moshi.models import loaders
from moshi.modules.transformer import quantize_transformer
from moshi.utils.quantize import replace_linear_with_qlinear


def get_api():
@@ -35,9 +35,9 @@ def main():
    print("Downloading base model.")
    info = loaders.CheckpointInfo.from_hf_repo(args.hf_repo)
    print("Creating model.")
    model = info.get_moshi(device='cuda')
    model = info.get_moshi(fuse_lora=True, device='cuda')
    print("Quantizing model.")
    quantize_transformer(model)
    replace_linear_with_qlinear(model)

    if args.new_hf_repo is None:
        new_repo = repo.rsplit('-', 1)[0] + '-q8'
142  scripts/import_mlx_lora.py  Normal file
@@ -0,0 +1,142 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from moshi.models import loaders
from pathlib import Path
import torch
from safetensors.torch import save_file


def import_model(
    tch_model,
    out_path: Path,
    silent: bool = False,
    max_out_n_q: int | None = None,
) -> None:
    in_n_q: int | None = None
    for idx in range(999):
        name = f"emb.{idx}.weight"
        if name not in tch_model:
            in_n_q = idx
            break
    out_n_q: int | None = None
    for idx in range(999):
        name = f"linears.{idx}.weight"
        if name not in tch_model:
            out_n_q = idx
            break
    assert in_n_q is not None
    assert out_n_q is not None
    if not silent:
        print(f"in_n_q: {in_n_q}, out_n_q: {out_n_q}")

    depformer_layers: int | None = None
    for idx in range(999):
        if f"depformer.layers.{idx}.self_attn.in_projs.0.weight" not in tch_model:
            depformer_layers = idx
            break
    assert depformer_layers is not None
    if not silent:
        print(f"depformer layers: {depformer_layers}")

    model = {}
    for name in ["text_emb.weight", "text_linear.weight"]:
        model[name] = tch_model[name]
    for name in tch_model.keys():
        if name.startswith("condition_provider.conditioners"):
            model[name] = tch_model[name]
    model["out_norm.weight"] = tch_model["out_norm.alpha"][0, 0]
    for idx in range(in_n_q):
        src_name = f"emb.{idx}.weight"
        dst_name = f"audio_embs.{idx}.weight"
        model[dst_name] = tch_model[src_name]

    for k, v in sorted(tch_model.items()):
        print(k, v.shape, v.dtype)
        if k.startswith("transformer"):
            if k.endswith(".alpha"):
                v = v[0, 0]
                k = k.replace(".alpha", ".weight")
            k = k.replace(".in_projs.0.weight", ".in_proj.weight")
            k = k.replace(".out_projs.0.weight", ".out_proj.weight")
            model[k] = v

    # Only export the first slices of the depformer (main).
    if max_out_n_q is not None:
        exported_out_n_q = min(max_out_n_q, out_n_q)
        print(f"only exporting the first {exported_out_n_q} depformer layers")
    else:
        exported_out_n_q = out_n_q

    for idx in range(exported_out_n_q):
        tch_idx = idx

        base = f"depformer.slices.{idx}."
        model[base + "linear_in.weight"] = tch_model[f"depformer_in.{tch_idx}.weight"].clone()
        model[base + "linear_out.weight"] = tch_model[f"linears.{idx}.weight"]
        if idx == 0:
            model[base + "emb.weight"] = tch_model["depformer_text_emb.weight"]
            if "depformer_text_emb.low_rank.weight" in tch_model:
                model[base + "emb.low_rank.weight"] = tch_model["depformer_text_emb.low_rank.weight"].clone()
        else:
            model[base + "emb.weight"] = tch_model[f"depformer_emb.{idx-1}.weight"].clone()
            if f"depformer_emb.{tch_idx-1}.low_rank.weight" in tch_model:
                model[base + "emb.low_rank.weight"] = tch_model[f"depformer_emb.{idx-1}.low_rank.weight"].clone()

        for layer_idx in range(depformer_layers):
            layer = base + f"transformer.layers.{layer_idx}."
            model[layer + "self_attn.in_proj.weight"] = (
                tch_model[f"depformer.layers.{layer_idx}.self_attn.in_projs.{tch_idx}.weight"]
            )
            model[layer + "self_attn.out_proj.weight"] = (
                tch_model[f"depformer.layers.{layer_idx}.self_attn.out_projs.{tch_idx}.weight"]
            )
            model[layer + "norm1.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.norm1.alpha"
            ][0, 0].clone()
            model[layer + "norm2.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.norm2.alpha"
            ][0, 0].clone()
            model[layer + "gating.linear_in.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_in.weight"
            ].clone()
            model[layer + "gating.linear_out.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_out.weight"
            ].clone()

    save_file(model, out_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults to Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
    parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
    parser.add_argument(
        "-s", "--silent", action="store_true", help="only prints the checkpoint name"
    )
    parser.add_argument(
        "--max-out-n-q",
        type=int,
        help="limit the number of depformer layers that are exported",
    )
    parser.add_argument("out", type=str, help="the mlx safetensors file to generate")

    args = parser.parse_args()
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
        lora_weights=args.lora_weight, config_path=args.config_path)
    lm = checkpoint_info.get_moshi(device="cpu", dtype=torch.bfloat16, fuse_lora=True)
    for key, value in lm.state_dict().items():
        print(key, value.shape)
    out_path = Path(args.out)
    import_model(lm.state_dict(), out_path, silent=args.silent, max_out_n_q=args.max_out_n_q)


if __name__ == "__main__":
    main()
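After running the importer above, the exported file can be sanity-checked by listing its keys; a short sketch using the standard safetensors API (the output path is a placeholder, and the expected key prefixes come from the renaming logic above):

    from safetensors.torch import load_file

    exported = load_file("moshi_mlx.safetensors")  # placeholder: the `out` argument above
    for key, value in sorted(exported.items()):
        print(key, tuple(value.shape), value.dtype)

    # Spot-check the renamed attention projections and the depformer slices.
    assert any(k.endswith("self_attn.in_proj.weight") for k in exported)
    assert any(k.startswith("depformer.slices.0.") for k in exported)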
139  scripts/import_rust_lora.py  Normal file
@@ -0,0 +1,139 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from moshi.models import loaders
from pathlib import Path
import torch
from safetensors.torch import save_file


def import_model(
    tch_model,
    max_out_n_q: int | None,
    out_path: Path,
) -> None:
    for k, v in sorted(tch_model.items()):
        print(k, v.shape, v.dtype)

    in_n_q: int | None = None
    for idx in range(999):
        name = f"emb.{idx}.weight"
        if name not in tch_model:
            in_n_q = idx
            break
    out_n_q: int | None = None
    for idx in range(999):
        name = f"linears.{idx}.weight"
        if name not in tch_model:
            out_n_q = idx
            break
    assert in_n_q is not None
    assert out_n_q is not None

    depformer_layers: int | None = None
    for idx in range(999):
        if f"depformer.layers.{idx}.self_attn.in_projs.0.weight" not in tch_model:
            depformer_layers = idx
            break
    assert depformer_layers is not None
    print(f"depformer layers: {depformer_layers}")

    model = {}
    for name in ["text_emb.weight", "text_linear.weight", "out_norm.alpha"]:
        model[name] = tch_model[name]
    for name in tch_model.keys():
        if name.startswith("condition_provider.conditioners"):
            model[name] = tch_model[name]
    for idx in range(in_n_q):
        name = f"emb.{idx}.weight"
        model[name] = tch_model[name]

    for k, v in sorted(tch_model.items()):
        if k.startswith("transformer"):
            k = k.replace(".in_projs.0.weight", ".in_proj.weight")
            k = k.replace(".out_projs.0.weight", ".out_proj.weight")
            model[k] = v

    # Only export the first slices of the depformer (main).
    if max_out_n_q is not None:
        exported_out_n_q = min(max_out_n_q, out_n_q)
        print(f"only exporting the first {exported_out_n_q} depformer layers")
    else:
        print(f"exporting all {out_n_q} depformer layers")
        exported_out_n_q = out_n_q

    max_df_steps = out_n_q

    for idx in range(exported_out_n_q):
        tch_idx = idx
        base = f"depformer.{idx}."
        model[base + "linear_in.weight"] = tch_model[f"depformer_in.{tch_idx}.weight"].clone()
        model[base + "linear_out.weight"] = tch_model[f"linears.{idx}.weight"]
        if idx == 0:
            model[base + "emb.weight"] = tch_model["depformer_text_emb.weight"]
            if "depformer_text_emb.low_rank.weight" in tch_model:
                model[base + "emb.low_rank.weight"] = tch_model["depformer_text_emb.low_rank.weight"].clone()
        else:
            model[base + "emb.weight"] = tch_model[f"depformer_emb.{tch_idx-1}.weight"].clone()
            if f"depformer_emb.{tch_idx-1}.low_rank.weight" in tch_model:
                model[base + "emb.low_rank.weight"] = tch_model[f"depformer_emb.{tch_idx-1}.low_rank.weight"].clone()

        for layer_idx in range(depformer_layers):
            layer = base + f"transformer.layers.{layer_idx}."
            # WARNING: note that this uses in_proj_weight vs out_proj.weight
            model[layer + "self_attn.in_proj_weight"] = (
                tch_model[f"depformer.layers.{layer_idx}.self_attn.in_projs.{tch_idx}.weight"]
            )
            model[layer + "self_attn.out_proj.weight"] = (
                tch_model[f"depformer.layers.{layer_idx}.self_attn.out_projs.{tch_idx}.weight"]
            )
            model[layer + "norm1.alpha"] = tch_model[
                f"depformer.layers.{layer_idx}.norm1.alpha"
            ].clone()
            model[layer + "norm2.alpha"] = tch_model[
                f"depformer.layers.{layer_idx}.norm2.alpha"
            ].clone()
            model[layer + "gating.linear_in.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_in.weight"
            ].clone()
            model[layer + "gating.linear_out.weight"] = tch_model[
                f"depformer.layers.{layer_idx}.gating.{tch_idx}.linear_out.weight"
            ].clone()

    save_file(model, out_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults to Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--lora-weight", type=str, help="Path to a local checkpoint file for LoRA.", default=None)
    parser.add_argument("--config-path", type=str, help="Path to a local config file.", default=None)
    parser.add_argument(
        "-s", "--silent", action="store_true", help="only prints the checkpoint name"
    )
    parser.add_argument(
        "--max-out-n-q",
        type=int,
        help="limit the number of depformer layers that are exported",
    )
    parser.add_argument("out", type=str, help="the rust safetensors file to generate")

    args = parser.parse_args()
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo, args.moshi_weight, args.mimi_weight, args.tokenizer,
        lora_weights=args.lora_weight, config_path=args.config_path)
    lm = checkpoint_info.get_moshi(device="cpu", dtype=torch.bfloat16, fuse_lora=True)
    for key, value in lm.state_dict().items():
        print(key, value.shape)
    out_path = Path(args.out)
    import_model(lm.state_dict(), out_path=out_path, max_out_n_q=args.max_out_n_q)


if __name__ == "__main__":
    main()