added flags to revert to the default method of latent generation (separately for the AR and diffusion latents, as some voices don't play nicely with the chunk-for-all method)

remotes/1710274000886183304/main
mrq 2023-05-21 01:46:55 +07:00
parent c90ee7c529
commit 5ff00bf3bf
2 changed files with 52 additions and 29 deletions

@@ -11,5 +11,5 @@ librosa==0.8.1
 torchaudio
 threadpoolctl
 appdirs
-numpy==1.23.5
+numpy<=1.23.5
 numba

@@ -448,13 +448,14 @@ class TextToSpeech:
         if self.preloaded_tensors:
             self.cvvp = migrate_to_device( self.cvvp, self.device )
 
-    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False):
+    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
         """
         Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
         These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
         properties.
         :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
         """
         with torch.no_grad():
             # computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
             if get_device_name() == "dml":
@@ -464,50 +465,72 @@ class TextToSpeech:
             if not isinstance(voice_samples, list):
                 voice_samples = [voice_samples]
-            voice_samples = [migrate_to_device(v, device) for v in voice_samples]
 
-            resampler = torchaudio.transforms.Resample(
+            resampler_22K = torchaudio.transforms.Resample(
                 self.input_sample_rate,
-                self.output_sample_rate,
+                22050,
                 lowpass_filter_width=16,
                 rolloff=0.85,
                 resampling_method="kaiser_window",
                 beta=8.555504641634386,
             ).to(device)
 
-            samples = [resampler(sample) for sample in voice_samples]
-            chunks = []
+            resampler_24K = torchaudio.transforms.Resample(
+                self.input_sample_rate,
+                24000,
+                lowpass_filter_width=16,
+                rolloff=0.85,
+                resampling_method="kaiser_window",
+                beta=8.555504641634386,
+            ).to(device)
 
-            concat = torch.cat(samples, dim=-1)
-            chunk_size = concat.shape[-1]
+            voice_samples = [migrate_to_device(v, device) for v in voice_samples]
 
-            if slices == 0:
-                slices = 1
-            elif max_chunk_size is not None and chunk_size > max_chunk_size:
-                slices = 1
-                while int(chunk_size / slices) > max_chunk_size:
-                    slices = slices + 1
+            auto_conds = []
+            diffusion_conds = []
 
-            chunks = torch.chunk(concat, slices, dim=1)
-            chunk_size = chunks[0].shape[-1]
+            if original_ar:
+                samples = [resampler_22K(sample) for sample in voice_samples]
+                for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
+                    auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
+            else:
+                samples = [resampler_22K(sample) for sample in voice_samples]
+                concat = torch.cat(samples, dim=-1)
+                chunk_size = concat.shape[-1]
+
+                if slices == 0:
+                    slices = 1
+                elif max_chunk_size is not None and chunk_size > max_chunk_size:
+                    slices = 1
+                    while int(chunk_size / slices) > max_chunk_size:
+                        slices = slices + 1
+
+                chunks = torch.chunk(concat, slices, dim=1)
+                chunk_size = chunks[0].shape[-1]
+
+                for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
+                    auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
 
-            auto_conds = []
-            for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
-                auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
+            if original_diffusion:
+                samples = [resampler_24K(sample) for sample in voice_samples]
+                for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
+                    sample = pad_or_truncate(sample, 102400)
+                    cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
+                    diffusion_conds.append(cond_mel)
+            else:
+                samples = [resampler_24K(sample) for sample in voice_samples]
+                for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
+                    check_for_kill_signal()
+                    chunk = pad_or_truncate(chunk, chunk_size)
+                    cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
+                    diffusion_conds.append(cond_mel)
+
             auto_conds = torch.stack(auto_conds, dim=1)
 
             self.autoregressive = migrate_to_device( self.autoregressive, device )
             auto_latent = self.autoregressive.get_conditioning(auto_conds)
             self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
 
-            diffusion_conds = []
-            for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
-                check_for_kill_signal()
-                chunk = pad_or_truncate(chunk, chunk_size)
-                cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
-                diffusion_conds.append(cond_mel)
             diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
             self.diffusion = migrate_to_device( self.diffusion, device )
             diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
             self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
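
For reference, a minimal caller-side sketch of the new flags, assuming a default TextToSpeech instance and torchaudio for loading; the clip paths and loading code are illustrative and not part of this commit:

# Hypothetical usage sketch (not part of the commit). original_ar and
# original_diffusion bypass the chunk-for-all path and fall back to the
# stock per-clip behavior: a fixed cond_length of 132300 samples for the
# AR latents and pad_or_truncate(sample, 102400) for the diffusion latents.
import torchaudio
from tortoise.api import TextToSpeech

tts = TextToSpeech()

# Two or more ~10 second reference clips as mono waveform tensors, already
# at the sample rate the instance expects (self.input_sample_rate).
voice_samples = []
for path in ["voices/example/clip1.wav", "voices/example/clip2.wav"]:
    wav, sample_rate = torchaudio.load(path)  # shape: (channels, samples)
    voice_samples.append(wav)

# Revert both latent computations to the original method; per the docstring
# the call returns (autoregressive_conditioning_latent, diffusion_conditioning_latent).
auto_latent, diffusion_latent = tts.get_conditioning_latents(
    voice_samples,
    original_ar=True,
    original_diffusion=True,
)

The two flags are independent, so a voice that only misbehaves with chunked AR latents can set original_ar=True while keeping the chunked diffusion path, or vice versa. On the chunked path, slices grows until chunk_size / slices fits under max_chunk_size: for example, a 30 second concatenation with max_chunk_size equal to 10 seconds of samples ends up split into 3 chunks.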