plop (#319)
This commit is contained in:
parent
7b684410aa
commit
192b9d82d4
5 changed files with 15 additions and 25 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -182,4 +182,5 @@ client/node_modules
|
|||
timings.json
|
||||
mlx-trace.json
|
||||
/scripts/token.txt
|
||||
uv.lock
|
||||
tts-outputs
|
||||
|
|
2
data/tts.jsonl
Normal file
2
data/tts.jsonl
Normal file
|
@ -0,0 +1,2 @@
|
|||
{"id": "hello", "turns": ["Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."], "voices": ["expresso/ex03-ex01_happy_001_channel1_334s.wav"]}
|
||||
{"id": "ceci_test", "turns": ["Bonjour, ceci est un test. Merci de bien vouloir respecter la consigne. Nous allons aussi synthétiser une voix en français."], "voices": ["unmute-prod-website/developpeuse-3.wav"]}
|
|
@ -410,6 +410,8 @@ class TTSModel:
|
|||
machine=machine, delay_steps=delay_steps,
|
||||
**kwargs)
|
||||
mimi.set_num_codebooks(tts_model.n_q)
|
||||
if not tts_model.multi_speaker:
|
||||
tts_model.voice_suffix = ''
|
||||
return tts_model
|
||||
|
||||
@cached_property
|
||||
|
@ -454,27 +456,6 @@ class TTSModel:
|
|||
**kwargs: passed to `moshi.models.lm.LMGen`.
|
||||
"""
|
||||
|
||||
def _main_wrapper(*args, **kwargs):
|
||||
transformer_out, text_logits = original(*args, **kwargs)
|
||||
if self.padding_bonus:
|
||||
text_logits[..., self.machine.token_ids.pad] += self.padding_bonus
|
||||
return transformer_out, text_logits
|
||||
|
||||
original = self.lm.forward_text
|
||||
self.lm.forward_text = _main_wrapper
|
||||
|
||||
try:
|
||||
return self._generate(all_entries, attributes, prefixes,
|
||||
cfg_is_no_prefix, cfg_is_no_text, **kwargs)
|
||||
finally:
|
||||
self.lm.forward_text = original
|
||||
|
||||
def _generate(self, all_entries: tp.Sequence[tp.Sequence[Entry]],
|
||||
attributes: tp.Sequence[ConditionAttributes],
|
||||
prefixes: list[torch.Tensor] | None = None,
|
||||
cfg_is_no_prefix: bool = True, cfg_is_no_text: bool = True,
|
||||
**kwargs):
|
||||
|
||||
if self.cfg_coef != 1.0:
|
||||
if self.valid_cfg_conditionings:
|
||||
raise ValueError(
|
||||
|
@ -513,6 +494,11 @@ class TTSModel:
|
|||
delayed = delayed.to(device)
|
||||
audio_prefixes.append(deque(delayed.t()))
|
||||
|
||||
def _on_text_logits_hook(text_logits):
|
||||
if self.padding_bonus:
|
||||
text_logits[..., self.machine.token_ids.pad] += self.padding_bonus
|
||||
return text_logits
|
||||
|
||||
def _on_audio_hook(audio_tokens):
|
||||
audio_offset = self.lm.audio_offset
|
||||
delays = self.lm.delays
|
||||
|
@ -544,7 +530,7 @@ class TTSModel:
|
|||
lm_gen = LMGen(
|
||||
self.lm, temp=self.temp, temp_text=self.temp,
|
||||
cfg_coef=self.cfg_coef, condition_tensors=condition_tensors,
|
||||
on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
|
||||
on_text_logits_hook=_on_text_logits_hook, on_text_hook=_on_text_hook, on_audio_hook=_on_audio_hook,
|
||||
cfg_is_masked_until=cfg_is_masked_until, cfg_is_no_text=cfg_is_no_text,
|
||||
**kwargs)
|
||||
|
||||
|
@ -621,7 +607,7 @@ class TTSModel:
|
|||
raise ValueError(f"Unsupported value for cfg_coef, valid values are {valids}.")
|
||||
return ConditionAttributes(text=text, tensor=tensors)
|
||||
|
||||
def get_prefix(self, audio_path: Path | str) -> torch.Tensor:
|
||||
def get_prefix(self, audio_path: Path) -> torch.Tensor:
|
||||
wav, _ = sphn.read(audio_path, sample_rate=self.mimi.sample_rate)
|
||||
with torch.no_grad():
|
||||
prefix = self.mimi.encode(torch.from_numpy(wav).to(device=self.lm.device)[None])[0, :, :-2]
|
||||
|
|
|
@ -15,7 +15,7 @@ import sphn
|
|||
import torch
|
||||
|
||||
from .models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
|
||||
from .models.loaders import CheckpointInfo, hf_get
|
||||
from .models.loaders import CheckpointInfo
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -118,7 +118,7 @@ def main():
|
|||
all_attributes.append(tts_model.make_condition_attributes(voices, cfg_coef_conditioning))
|
||||
if prefixes is not None:
|
||||
assert len(request.voices) == 1, "For this model, only exactly one voice is supported."
|
||||
prefix_path = hf_get(request.voices[0], args.voice_repo, check_local_file_exists=True)
|
||||
prefix_path = tts_model.get_voice_path(request.voices[0])
|
||||
prefixes.append(tts_model.get_prefix(prefix_path))
|
||||
|
||||
print(f"Starting batch of size {len(batch)}")
|
||||
|
|
|
@ -7,4 +7,5 @@ ignore = E203,E704
|
|||
exclude =
|
||||
dist
|
||||
build
|
||||
.venv
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue