diff --git a/tortoise_tts/emb/mel.py b/tortoise_tts/emb/mel.py index 6af96a1..4084829 100755 --- a/tortoise_tts/emb/mel.py +++ b/tortoise_tts/emb/mel.py @@ -80,6 +80,8 @@ def format_diffusion_conditioning( sample, device, do_normalization=False ): # encode a wav to conditioning latents + mel codes @torch.inference_mode() def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"): + wav = torchaudio.functional.resample(wav, sr, 22050) + dvae = load_model("dvae", device=device) unified_voice = load_model("unified_voice", device=device) diffusion = load_model("diffusion", device=device) diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py index eb45ac1..444cfd8 100755 --- a/tortoise_tts/train.py +++ b/tortoise_tts/train.py @@ -4,7 +4,7 @@ from .config import cfg from .data import create_train_val_dataloader from .emb import mel -from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc +from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc, wrapper as ml from .utils.distributed import is_global_leader import auraloss @@ -165,22 +165,24 @@ def run_eval(engines, eval_name, dl): break # diffusion pass - output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. - output_shape = (latents.shape[0], 100, output_seq_len) - precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False) + with ml.auto_unload(diffusion, enabled=True): + output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False) - noise = torch.randn(output_shape, device=latents.device) * temperature - mel = diffuser.p_sample_loop( - diffusion, - output_shape, - noise=noise, - model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, - progress=True - ) - mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.p_sample_loop( + diffusion, + output_shape, + noise=noise, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=True + ) + mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len] # vocoder pass - wavs = vocoder.inference(mels) + with ml.auto_unload(vocoder, enabled=True): + wavs = vocoder.inference(mels) return wavs