Alexandre Défossez 2025-07-02 15:08:16 +02:00 committed by GitHub
parent 7b684410aa
commit 192b9d82d4
5 changed files with 15 additions and 25 deletions

.gitignore

@@ -182,4 +182,5 @@ client/node_modules
timings.json
mlx-trace.json
/scripts/token.txt
uv.lock
tts-outputs

data/tts.jsonl

@@ -0,0 +1,2 @@
{"id": "hello", "turns": ["Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."], "voices": ["expresso/ex03-ex01_happy_001_channel1_334s.wav"]}
{"id": "ceci_test", "turns": ["Bonjour, ceci est un test. Merci de bien vouloir respecter la consigne. Nous allons aussi synthétiser une voix en français."], "voices": ["unmute-prod-website/developpeuse-3.wav"]}

View file

@@ -410,6 +410,8 @@ class TTSModel:
machine=machine, delay_steps=delay_steps,
**kwargs)
mimi.set_num_codebooks(tts_model.n_q)
if not tts_model.multi_speaker:
tts_model.voice_suffix = ''
return tts_model
@cached_property
@@ -454,27 +456,6 @@ class TTSModel:
**kwargs: passed to `moshi.models.lm.LMGen`.
"""
def _main_wrapper(*args, **kwargs):
transformer_out, text_logits = original(*args, **kwargs)
if self.padding_bonus:
text_logits[..., self.machine.token_ids.pad] += self.padding_bonus
return transformer_out, text_logits
original = self.lm.forward_text
self.lm.forward_text = _main_wrapper
try:
return self._generate(all_entries, attributes, prefixes,
cfg_is_no_prefix, cfg_is_no_text, **kwargs)
finally:
self.lm.forward_text = original
def _generate(self, all_entries: tp.Sequence[tp.Sequence[Entry]],
attributes: tp.Sequence[ConditionAttributes],
prefixes: list[torch.Tensor] | None = None,
cfg_is_no_prefix: bool = True, cfg_is_no_text: bool = True,
**kwargs):
if self.cfg_coef != 1.0:
if self.valid_cfg_conditionings:
raise ValueError(
@@ -513,6 +494,11 @@ class TTSModel:
delayed = delayed.to(device)
audio_prefixes.append(deque(delayed.t()))
def _on_text_logits_hook(text_logits):
if self.padding_bonus:
text_logits[..., self.machine.token_ids.pad] += self.padding_bonus
return text_logits
def _on_audio_hook(audio_tokens):
audio_offset = self.lm.audio_offset
delays = self.lm.delays
@@ -544,7 +530,7 @@ class TTSModel:
lm_gen = LMGen(
self.lm, temp=self.temp, temp_text=self.temp,
cfg_coef=self.cfg_coef, condition_tensors=condition_tensors,
-on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
+on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
cfg_is_masked_until=cfg_is_masked_until, cfg_is_no_text=cfg_is_no_text,
**kwargs)
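The deleted `_main_wrapper` around `self.lm.forward_text` and the new `_on_text_logits_hook` do the same job: add `padding_bonus` to the logit of the padding token before sampling, now wired in through `LMGen`'s `on_text_logits_hook` argument instead of monkey-patching the language model. A self-contained sketch of that biasing step (the `make_padding_bonus_hook` factory and the token id used below are illustrative):

```python
import torch


def make_padding_bonus_hook(pad_token_id: int, padding_bonus: float):
    # Returns a hook that bumps the pad token's logit by a fixed bonus,
    # nudging generation toward (or away from) padding.
    def on_text_logits(text_logits: torch.Tensor) -> torch.Tensor:
        if padding_bonus:
            text_logits[..., pad_token_id] += padding_bonus
        return text_logits
    return on_text_logits


# Toy usage: bias token id 3 in a batch of zero logits.
hook = make_padding_bonus_hook(pad_token_id=3, padding_bonus=2.0)
print(hook(torch.zeros(1, 1, 8))[0, 0])
# tensor([0., 0., 0., 2., 0., 0., 0., 0.])
```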
@@ -621,7 +607,7 @@ class TTSModel:
raise ValueError(f"Unsupported value for cfg_coef, valid values are {valids}.")
return ConditionAttributes(text=text, tensor=tensors)
-def get_prefix(self, audio_path: Path | str) -> torch.Tensor:
+def get_prefix(self, audio_path: Path) -> torch.Tensor:
wav, _ = sphn.read(audio_path, sample_rate=self.mimi.sample_rate)
with torch.no_grad():
prefix = self.mimi.encode(torch.from_numpy(wav).to(device=self.lm.device)[None])[0, :, :-2]

View file

@@ -15,7 +15,7 @@ import sphn
import torch
from .models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
-from .models.loaders import CheckpointInfo, hf_get
+from .models.loaders import CheckpointInfo
@dataclass
@@ -118,7 +118,7 @@ def main():
all_attributes.append(tts_model.make_condition_attributes(voices, cfg_coef_conditioning))
if prefixes is not None:
assert len(request.voices) == 1, "For this model, only exactly one voice is supported."
-prefix_path = hf_get(request.voices[0], args.voice_repo, check_local_file_exists=True)
+prefix_path = tts_model.get_voice_path(request.voices[0])
prefixes.append(tts_model.get_prefix(prefix_path))
print(f"Starting batch of size {len(batch)}")

View file

@@ -7,4 +7,5 @@ ignore = E203,E704
exclude =
dist
build
.venv