diff --git a/requirements.txt b/requirements.txt
index 594507f..c96ba68 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,10 +4,10 @@ transformers==4.19
 tokenizers
 inflect
 progressbar
-einops==0.6.0
+einops
 unidecode
 scipy
-librosa==0.8.0
+librosa==0.8.1
 torchaudio
 threadpoolctl
 appdirs
diff --git a/tortoise/api.py b/tortoise/api.py
index 2df47ea..f1a60f2 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -442,19 +442,7 @@ class TextToSpeech:
             beta=8.555504641634386,
         ).to(device)
 
-        samples = []
-        auto_conds = []
-        for sample in voice_samples:
-            auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
-            samples.append(resampler(sample))
-
-        auto_conds = torch.stack(auto_conds, dim=1)
-
-        self.autoregressive = migrate_to_device( self.autoregressive, device )
-        auto_latent = self.autoregressive.get_conditioning(auto_conds)
-        self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
-
-        diffusion_conds = []
+        samples = [resampler(sample) for sample in voice_samples]
 
         chunks = []
         concat = torch.cat(samples, dim=-1)
@@ -469,15 +457,22 @@ class TextToSpeech:
             chunks = torch.chunk(concat, slices, dim=1)
             chunk_size = chunks[0].shape[-1]
 
+
+        auto_conds = []
+        for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing AR conditioning latents..."):
+            auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
+        auto_conds = torch.stack(auto_conds, dim=1)
+
+        self.autoregressive = migrate_to_device( self.autoregressive, device )
+        auto_latent = self.autoregressive.get_conditioning(auto_conds)
+        self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
 
-        # expand / truncate samples to match the common size
-        # required, as tensors need to be of the same length
-        for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
+        diffusion_conds = []
+        for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing diffusion conditioning latents..."):
             check_for_kill_signal()
             chunk = pad_or_truncate(chunk, chunk_size)
             cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
             diffusion_conds.append(cond_mel)
-
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
         self.diffusion = migrate_to_device( self.diffusion, device )
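For reference, the api.py change above restructures TextToSpeech's conditioning pass: previously the autoregressive conditioning was computed from each raw voice sample while the diffusion conditioning was computed from chunks of the concatenated audio; now both latents are derived from the same fixed-size chunks, and format_conditioning receives cond_length=chunk_size so the AR conditioning windows match the chunk length. Below is a minimal, self-contained sketch of that shared-chunk flow, not the actual tortoise API: make_ar_cond and make_diffusion_cond are hypothetical placeholders standing in for format_conditioning and wav_to_univnet_mel, and the device/progress handling from the diff is omitted.

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for format_conditioning / wav_to_univnet_mel;
# the real functions produce conditioning mels, these return tiny dummy
# feature tensors so the sketch runs anywhere.
def make_ar_cond(chunk):
    return chunk.mean(dim=-1, keepdim=True)

def make_diffusion_cond(chunk):
    return chunk.abs().mean(dim=-1, keepdim=True)

def pad_or_truncate(t, length):
    # Zero-pad or cut so every chunk shares one common length,
    # mirroring the pad_or_truncate call in the diff.
    if t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))
    return t[..., :length]

def conditioning_from_chunks(voice_samples, slices=2):
    # Concatenate all reference clips, chunk the result once, and feed
    # the SAME chunks to both conditioning paths -- the core idea of
    # the refactor above.
    concat = torch.cat(voice_samples, dim=-1)
    chunks = torch.chunk(concat, slices, dim=1)
    chunk_size = chunks[0].shape[-1]
    auto_conds = torch.stack(
        [make_ar_cond(pad_or_truncate(c, chunk_size)) for c in chunks], dim=1)
    diffusion_conds = torch.stack(
        [make_diffusion_cond(pad_or_truncate(c, chunk_size)) for c in chunks], dim=1)
    return auto_conds, diffusion_conds

samples = [torch.randn(1, 22050), torch.randn(1, 33000)]
auto_conds, diffusion_conds = conditioning_from_chunks(samples)
print(auto_conds.shape, diffusion_conds.shape)  # torch.Size([1, 2, 1]) twice

One practical effect visible in the diff itself: because the AR latents are now computed per chunk rather than per sample, the two conditioning tensors stay aligned on the same audio windows, and each pass gets its own progress description ("Computing AR conditioning latents..." / "Computing diffusion conditioning latents...").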