Why didn't I also have it use chunks for computing the AR conditioning latents (instead of just the diffusion conditioning latents)?

This commit is contained in:
mrq 2023-03-14 01:13:49 +00:00
parent 97cd58e7eb
commit 65a43deb9e
2 changed files with 14 additions and 19 deletions

View File

@ -4,10 +4,10 @@ transformers==4.19
tokenizers
inflect
progressbar
einops==0.6.0
einops
unidecode
scipy
librosa==0.8.0
librosa==0.8.1
torchaudio
threadpoolctl
appdirs

View File

@ -442,19 +442,7 @@ class TextToSpeech:
beta=8.555504641634386,
).to(device)
samples = []
auto_conds = []
for sample in voice_samples:
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
samples.append(resampler(sample))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device )
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
diffusion_conds = []
samples = [resampler(sample) for sample in voice_samples]
chunks = []
concat = torch.cat(samples, dim=-1)
@ -469,15 +457,22 @@ class TextToSpeech:
chunks = torch.chunk(concat, slices, dim=1)
chunk_size = chunks[0].shape[-1]
auto_conds = []
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device )
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
# expand / truncate samples to match the common size
# required, as tensors need to be of the same length
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
diffusion_conds = []
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = migrate_to_device( self.diffusion, device )