1
1
forked from mrq/tortoise-tts

added flags to rever to default method of latent generation (separately for the AR and Diffusion latents, as some voices don't play nicely with the chunk-for-all method)

This commit is contained in:
mrq 2023-05-21 01:46:55 +00:00
parent c90ee7c529
commit 5ff00bf3bf
2 changed files with 54 additions and 31 deletions

View File

@ -11,5 +11,5 @@ librosa==0.8.1
torchaudio
threadpoolctl
appdirs
numpy==1.23.5
numpy<=1.23.5
numba

View File

@ -448,13 +448,14 @@ class TextToSpeech:
if self.preloaded_tensors:
self.cvvp = migrate_to_device( self.cvvp, self.device )
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False):
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
"""
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
properties.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
"""
with torch.no_grad():
# computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
if get_device_name() == "dml":
@ -464,50 +465,72 @@ class TextToSpeech:
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
voice_samples = [migrate_to_device(v, device) for v in voice_samples]
resampler = torchaudio.transforms.Resample(
resampler_22K = torchaudio.transforms.Resample(
self.input_sample_rate,
self.output_sample_rate,
22050,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
).to(device)
samples = [resampler(sample) for sample in voice_samples]
chunks = []
resampler_24K = torchaudio.transforms.Resample(
self.input_sample_rate,
24000,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
).to(device)
concat = torch.cat(samples, dim=-1)
chunk_size = concat.shape[-1]
if slices == 0:
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size:
slices = 1
while int(chunk_size / slices) > max_chunk_size:
slices = slices + 1
chunks = torch.chunk(concat, slices, dim=1)
chunk_size = chunks[0].shape[-1]
voice_samples = [migrate_to_device(v, device) for v in voice_samples]
auto_conds = []
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
auto_conds = torch.stack(auto_conds, dim=1)
diffusion_conds = []
if original_ar:
samples = [resampler_22K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
else:
samples = [resampler_22K(sample) for sample in voice_samples]
concat = torch.cat(samples, dim=-1)
chunk_size = concat.shape[-1]
if slices == 0:
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size:
slices = 1
while int(chunk_size / slices) > max_chunk_size:
slices = slices + 1
chunks = torch.chunk(concat, slices, dim=1)
chunk_size = chunks[0].shape[-1]
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
if original_diffusion:
samples = [resampler_24K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel)
else:
samples = [resampler_24K(sample) for sample in voice_samples]
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device )
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
diffusion_conds = []
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = migrate_to_device( self.diffusion, device )
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )