forked from mrq/tortoise-tts
Added flags to revert to the default method of latent generation (separately for the AR and diffusion latents, as some voices don't play nicely with the chunk-for-all method)
This commit is contained in:
parent
c90ee7c529
commit
5ff00bf3bf
|
@ -11,5 +11,5 @@ librosa==0.8.1
|
|||
torchaudio
|
||||
threadpoolctl
|
||||
appdirs
|
||||
numpy==1.23.5
|
||||
numpy<=1.23.5
|
||||
numba
|
|
@ -448,13 +448,14 @@ class TextToSpeech:
|
|||
if self.preloaded_tensors:
|
||||
self.cvvp = migrate_to_device( self.cvvp, self.device )
|
||||
|
||||
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False):
|
||||
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
|
||||
"""
|
||||
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
|
||||
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
|
||||
properties.
|
||||
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
# computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
|
||||
if get_device_name() == "dml":
|
||||
|
@ -464,50 +465,72 @@ class TextToSpeech:
|
|||
if not isinstance(voice_samples, list):
|
||||
voice_samples = [voice_samples]
|
||||
|
||||
voice_samples = [migrate_to_device(v, device) for v in voice_samples]
|
||||
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
resampler_22K = torchaudio.transforms.Resample(
|
||||
self.input_sample_rate,
|
||||
self.output_sample_rate,
|
||||
22050,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
).to(device)
|
||||
|
||||
samples = [resampler(sample) for sample in voice_samples]
|
||||
chunks = []
|
||||
resampler_24K = torchaudio.transforms.Resample(
|
||||
self.input_sample_rate,
|
||||
24000,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
).to(device)
|
||||
|
||||
concat = torch.cat(samples, dim=-1)
|
||||
chunk_size = concat.shape[-1]
|
||||
|
||||
if slices == 0:
|
||||
slices = 1
|
||||
elif max_chunk_size is not None and chunk_size > max_chunk_size:
|
||||
slices = 1
|
||||
while int(chunk_size / slices) > max_chunk_size:
|
||||
slices = slices + 1
|
||||
|
||||
chunks = torch.chunk(concat, slices, dim=1)
|
||||
chunk_size = chunks[0].shape[-1]
|
||||
voice_samples = [migrate_to_device(v, device) for v in voice_samples]
|
||||
|
||||
auto_conds = []
|
||||
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
|
||||
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
|
||||
auto_conds = torch.stack(auto_conds, dim=1)
|
||||
diffusion_conds = []
|
||||
|
||||
if original_ar:
|
||||
samples = [resampler_22K(sample) for sample in voice_samples]
|
||||
for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
|
||||
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
|
||||
else:
|
||||
samples = [resampler_22K(sample) for sample in voice_samples]
|
||||
concat = torch.cat(samples, dim=-1)
|
||||
chunk_size = concat.shape[-1]
|
||||
|
||||
if slices == 0:
|
||||
slices = 1
|
||||
elif max_chunk_size is not None and chunk_size > max_chunk_size:
|
||||
slices = 1
|
||||
while int(chunk_size / slices) > max_chunk_size:
|
||||
slices = slices + 1
|
||||
|
||||
chunks = torch.chunk(concat, slices, dim=1)
|
||||
chunk_size = chunks[0].shape[-1]
|
||||
|
||||
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
|
||||
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
|
||||
|
||||
|
||||
if original_diffusion:
|
||||
samples = [resampler_24K(sample) for sample in voice_samples]
|
||||
for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
|
||||
sample = pad_or_truncate(sample, 102400)
|
||||
cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
|
||||
diffusion_conds.append(cond_mel)
|
||||
else:
|
||||
samples = [resampler_24K(sample) for sample in voice_samples]
|
||||
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
|
||||
check_for_kill_signal()
|
||||
chunk = pad_or_truncate(chunk, chunk_size)
|
||||
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
|
||||
diffusion_conds.append(cond_mel)
|
||||
|
||||
auto_conds = torch.stack(auto_conds, dim=1)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, device )
|
||||
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
|
||||
|
||||
diffusion_conds = []
|
||||
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
|
||||
check_for_kill_signal()
|
||||
chunk = pad_or_truncate(chunk, chunk_size)
|
||||
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
|
||||
diffusion_conds.append(cond_mel)
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
self.diffusion = migrate_to_device( self.diffusion, device )
|
||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
||||
self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
|
||||
|
|
Loading…
Reference in New Issue
Block a user