@ -448,13 +448,14 @@ class TextToSpeech:
if self . preloaded_tensors :
self . cvvp = migrate_to_device ( self . cvvp , self . device )
def get_conditioning_latents ( self , voice_samples , return_mels = False , verbose = False , slices = 1 , max_chunk_size = None , force_cpu = False ):
def get_conditioning_latents ( self , voice_samples , return_mels = False , verbose = False , slices = 1 , max_chunk_size = None , force_cpu = False , original_ar = False , original_diffusion = False ):
"""
Transforms one or more voice_samples into a tuple ( autoregressive_conditioning_latent , diffusion_conditioning_latent ) .
These are expressive learned latents that encode aspects of the provided clips like voice , intonation , and acoustic
properties .
: param voice_samples : List of 2 or more ~ 10 second reference clips , which should be torch tensors containing 22.05 kHz waveform data .
"""
with torch . no_grad ( ) :
# computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
if get_device_name ( ) == " dml " :
@ -464,50 +465,72 @@ class TextToSpeech:
if not isinstance ( voice_samples , list ) :
voice_samples = [ voice_samples ]
voice_samples = [ migrate_to_device ( v , device ) for v in voice_samples ]
resampler_22K = torchaudio . transforms . Resample (
self . input_sample_rate ,
22050 ,
lowpass_filter_width = 16 ,
rolloff = 0.85 ,
resampling_method = " kaiser_window " ,
beta = 8.555504641634386 ,
) . to ( device )
resampler = torchaudio . transforms . Resample (
resampler _24K = torchaudio . transforms . Resample (
self . input_sample_rate ,
self . output_sample_rate ,
24000 ,
lowpass_filter_width = 16 ,
rolloff = 0.85 ,
resampling_method = " kaiser_window " ,
beta = 8.555504641634386 ,
) . to ( device )
samples = [ resampler ( sample ) for sample in voice_samples ]
chunks = [ ]
voice_samples = [ migrate_to_device ( v , device ) for v in voice_samples ]
concat = torch . cat ( samples , dim = - 1 )
chunk_size = concat . shape [ - 1 ]
auto_conds = [ ]
diffusion_conds = [ ]
if slices == 0 :
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size :
slices = 1
while int ( chunk_size / slices ) > max_chunk_size :
slices = slices + 1
if original_ar :
samples = [ resampler_22K ( sample ) for sample in voice_samples ]
for sample in tqdm ( samples , desc = " Computing AR conditioning latents... " ) :
auto_conds . append ( format_conditioning ( sample , device = device , sampling_rate = self . input_sample_rate , cond_length = 132300 ) )
else :
samples = [ resampler_22K ( sample ) for sample in voice_samples ]
concat = torch . cat ( samples , dim = - 1 )
chunk_size = concat . shape [ - 1 ]
if slices == 0 :
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size :
slices = 1
while int ( chunk_size / slices ) > max_chunk_size :
slices = slices + 1
chunks = torch . chunk ( concat , slices , dim = 1 )
chunk_size = chunks [ 0 ] . shape [ - 1 ]
for chunk in tqdm ( chunks , desc = " Computing AR conditioning latents... " ) :
auto_conds . append ( format_conditioning ( chunk , device = device , sampling_rate = self . input_sample_rate , cond_length = chunk_size ) )
chunks = torch . chunk ( concat , slices , dim = 1 )
chunk_size = chunks [ 0 ] . shape [ - 1 ]
if original_diffusion :
samples = [ resampler_24K ( sample ) for sample in voice_samples ]
for sample in tqdm ( samples , desc = " Computing diffusion conditioning latents... " ) :
sample = pad_or_truncate ( sample , 102400 )
cond_mel = wav_to_univnet_mel ( migrate_to_device ( sample , device ) , do_normalization = False , device = self . device )
diffusion_conds . append ( cond_mel )
else :
samples = [ resampler_24K ( sample ) for sample in voice_samples ]
for chunk in tqdm ( chunks , desc = " Computing diffusion conditioning latents... " ) :
check_for_kill_signal ( )
chunk = pad_or_truncate ( chunk , chunk_size )
cond_mel = wav_to_univnet_mel ( migrate_to_device ( chunk , device ) , do_normalization = False , device = device )
diffusion_conds . append ( cond_mel )
auto_conds = [ ]
for chunk in tqdm ( chunks , desc = " Computing AR conditioning latents... " ) :
auto_conds . append ( format_conditioning ( chunk , device = device , sampling_rate = self . input_sample_rate , cond_length = chunk_size ) )
auto_conds = torch . stack ( auto_conds , dim = 1 )
self . autoregressive = migrate_to_device ( self . autoregressive , device )
auto_latent = self . autoregressive . get_conditioning ( auto_conds )
self . autoregressive = migrate_to_device ( self . autoregressive , self . device if self . preloaded_tensors else ' cpu ' )
diffusion_conds = [ ]
for chunk in tqdm ( chunks , desc = " Computing diffusion conditioning latents... " ) :
check_for_kill_signal ( )
chunk = pad_or_truncate ( chunk , chunk_size )
cond_mel = wav_to_univnet_mel ( migrate_to_device ( chunk , device ) , do_normalization = False , device = device )
diffusion_conds . append ( cond_mel )
diffusion_conds = torch . stack ( diffusion_conds , dim = 1 )
diffusion_conds = torch . stack ( diffusion_conds , dim = 1 )
self . diffusion = migrate_to_device ( self . diffusion , device )
diffusion_latent = self . diffusion . get_conditioning ( diffusion_conds )
self . diffusion = migrate_to_device ( self . diffusion , self . device if self . preloaded_tensors else ' cpu ' )