added flags to revert to the default method of latent generation (separately for the AR and Diffusion latents, as some voices don't play nicely with the chunk-for-all method)

mrq 2023-05-21 01:46:55 +00:00
parent c90ee7c529
commit 5ff00bf3bf
2 changed files with 54 additions and 31 deletions
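
For reference, a minimal usage sketch of the two new keyword arguments introduced below. This is not part of the commit: it assumes the repo's usual tortoise.api.TextToSpeech and tortoise.utils.audio.load_audio entry points, and the clip paths are placeholders.

# Hedged usage sketch (not included in this commit): passing the new flags.
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio

tts = TextToSpeech()

# Two or more ~10 second reference clips as 22.05kHz waveform tensors.
voice_samples = [load_audio(path, 22050) for path in ["voice/clip1.wav", "voice/clip2.wav"]]

# original_ar / original_diffusion fall back to the per-clip (stock) latent computation
# instead of the concatenate-and-chunk method described in the commit message.
latents = tts.get_conditioning_latents(
    voice_samples,
    original_ar=True,
    original_diffusion=True,
)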


@@ -11,5 +11,5 @@ librosa==0.8.1
 torchaudio
 threadpoolctl
 appdirs
-numpy==1.23.5
+numpy<=1.23.5
 numba


@@ -448,13 +448,14 @@ class TextToSpeech:
 		if self.preloaded_tensors:
 			self.cvvp = migrate_to_device( self.cvvp, self.device )
-	def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False):
+	def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
 		"""
 		Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
 		These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
 		properties.
 		:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
 		"""
 		with torch.no_grad():
 			# computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
 			if get_device_name() == "dml":
@@ -464,20 +465,35 @@ class TextToSpeech:
 			if not isinstance(voice_samples, list):
 				voice_samples = [voice_samples]
-			voice_samples = [migrate_to_device(v, device) for v in voice_samples]
-			resampler = torchaudio.transforms.Resample(
+			resampler_22K = torchaudio.transforms.Resample(
 				self.input_sample_rate,
-				self.output_sample_rate,
+				22050,
 				lowpass_filter_width=16,
 				rolloff=0.85,
 				resampling_method="kaiser_window",
 				beta=8.555504641634386,
 			).to(device)
-			samples = [resampler(sample) for sample in voice_samples]
-			chunks = []
+			resampler_24K = torchaudio.transforms.Resample(
+				self.input_sample_rate,
+				24000,
+				lowpass_filter_width=16,
+				rolloff=0.85,
+				resampling_method="kaiser_window",
+				beta=8.555504641634386,
+			).to(device)
+			voice_samples = [migrate_to_device(v, device) for v in voice_samples]
+			auto_conds = []
+			diffusion_conds = []
+			if original_ar:
+				samples = [resampler_22K(sample) for sample in voice_samples]
+				for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
+					auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
+			else:
+				samples = [resampler_22K(sample) for sample in voice_samples]
 			concat = torch.cat(samples, dim=-1)
 			chunk_size = concat.shape[-1]
@@ -491,23 +507,30 @@ class TextToSpeech:
 			chunks = torch.chunk(concat, slices, dim=1)
 			chunk_size = chunks[0].shape[-1]
-			auto_conds = []
 			for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
 				auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
-			auto_conds = torch.stack(auto_conds, dim=1)
-			self.autoregressive = migrate_to_device( self.autoregressive, device )
-			auto_latent = self.autoregressive.get_conditioning(auto_conds)
-			self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
-			diffusion_conds = []
+			if original_diffusion:
+				samples = [resampler_24K(sample) for sample in voice_samples]
+				for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
+					sample = pad_or_truncate(sample, 102400)
+					cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
+					diffusion_conds.append(cond_mel)
+			else:
+				samples = [resampler_24K(sample) for sample in voice_samples]
 			for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
 				check_for_kill_signal()
 				chunk = pad_or_truncate(chunk, chunk_size)
 				cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
 				diffusion_conds.append(cond_mel)
-			diffusion_conds = torch.stack(diffusion_conds, dim=1)
+			auto_conds = torch.stack(auto_conds, dim=1)
+			self.autoregressive = migrate_to_device( self.autoregressive, device )
+			auto_latent = self.autoregressive.get_conditioning(auto_conds)
+			self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
+			diffusion_conds = torch.stack(diffusion_conds, dim=1)
 			self.diffusion = migrate_to_device( self.diffusion, device )
 			diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
 			self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
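
As an aside, a rough illustrative sketch (not part of the commit) of why the two paths differ: the chunk-for-all method concatenates every reference clip and splits the result into `slices` equal chunks, so a chunk can straddle clip boundaries, while the reverted default computes one fixed-length window per clip (132300 samples at 22.05kHz for the AR path, 102400 samples at 24kHz for diffusion, per the diff above). pad_or_truncate_like is a hypothetical stand-in for the repo's pad_or_truncate helper, and the clip lengths are made up.

import torch
import torch.nn.functional as F

# Two fake reference clips of different lengths (1 channel, 22.05kHz samples).
clips = [torch.randn(1, 22050 * 8), torch.randn(1, 22050 * 12)]

# Chunk-for-all method: concatenate everything, then split into `slices` equal chunks.
slices = 2
concat = torch.cat(clips, dim=-1)
chunks = torch.chunk(concat, slices, dim=1)
print([c.shape[-1] for c in chunks])    # [220500, 220500] -- chunks ignore clip boundaries

def pad_or_truncate_like(t, length):
    # Hypothetical stand-in for the repo's pad_or_truncate helper.
    if t.shape[-1] >= length:
        return t[..., :length]
    return F.pad(t, (0, length - t.shape[-1]))

# Reverted default: one fixed-length conditioning window per clip.
per_clip = [pad_or_truncate_like(c, 132300) for c in clips]
print([c.shape[-1] for c in per_clip])  # [132300, 132300] -- one window per clip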