2022-03-29 01:33:31 +00:00
import os
import random
2022-05-02 21:40:03 +00:00
import uuid
2023-02-09 22:17:57 +00:00
import gc
2022-05-17 18:11:18 +00:00
from time import time
2022-03-29 01:33:31 +00:00
from urllib import request
2023-03-03 06:30:58 +00:00
from urllib . request import ProxyHandler , build_opener , install_opener
2022-03-29 01:33:31 +00:00
import torch
import torch . nn . functional as F
import progressbar
2022-04-22 17:34:05 +00:00
import torchaudio
2022-03-29 01:33:31 +00:00
2022-05-01 22:24:24 +00:00
from tortoise . models . classifier import AudioMiniEncoderWithClassifierHead
from tortoise . models . diffusion_decoder import DiffusionTts
from tortoise . models . autoregressive import UnifiedVoice
2022-03-29 01:33:31 +00:00
from tqdm import tqdm
2022-05-01 22:24:24 +00:00
from tortoise . models . arch_util import TorchMelSpectrogram
from tortoise . models . clvp import CLVP
2022-05-25 10:22:50 +00:00
from tortoise . models . cvvp import CVVP
2022-05-02 21:40:03 +00:00
from tortoise . models . random_latent_generator import RandomLatentConverter
2022-05-01 22:24:24 +00:00
from tortoise . models . vocoder import UnivNetGenerator
2023-03-03 06:30:58 +00:00
from tortoise . models . bigvgan import BigVGAN
2022-05-01 22:24:24 +00:00
from tortoise . utils . audio import wav_to_univnet_mel , denormalize_tacotron_mel
from tortoise . utils . diffusion import SpacedDiffusion , space_timesteps , get_named_beta_schedule
from tortoise . utils . tokenizer import VoiceBpeTokenizer
2022-05-02 20:57:29 +00:00
from tortoise . utils . wav2vec_alignment import Wav2VecAlignment
2022-03-29 01:33:31 +00:00
2023-02-09 01:53:25 +00:00
from tortoise . utils . device import get_device , get_device_name , get_device_batch_size
2022-03-29 01:33:31 +00:00
# Shared progress bar used by the urlretrieve callback in download_models(); None while no bar is active.
pbar = None
# Abort flag polled by check_for_kill_signal(), which raises and clears it when set.
# Nothing in this module sets it to True; presumably an external caller flips it to cancel a job.
STOP_SIGNAL = False
# Directory holding model checkpoints. Taken from the TORTOISE_MODELS_DIR environment variable when present,
# otherwise ./models/tortoise/ relative to the working directory at import time.
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/')))

_TORTOISE_V2_URL_PREFIX = 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/'
_ECKER_URL_PREFIX = 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/'

# Checkpoint file name -> download URL. download_models() walks this mapping in insertion order.
MODELS = {
    **{
        name: _TORTOISE_V2_URL_PREFIX + name
        for name in (
            'autoregressive.pth',
            'classifier.pth',
            'clvp2.pth',
            'cvvp.pth',
            'diffusion_decoder.pth',
            'vocoder.pth',
            'rlg_auto.pth',
            'rlg_diffuser.pth',
        )
    },
    'bigvgan_base_24khz_100band.pth': _ECKER_URL_PREFIX + 'bigvgan_base_24khz_100band.pth',
}
2022-04-21 22:06:43 +00:00
2023-03-02 00:44:42 +00:00
def hash_file(path, algo="md5", buffer_size=0):
    """
    Return the hexadecimal digest of the file at `path`.

    :param path: File to hash.
    :param algo: "md5" (default) or "sha1".
    :param buffer_size: Number of bytes read per chunk. A value <= 0 (the default) also streams the file,
                        in 1 MiB chunks, instead of reading it into memory in one piece. The digest is the
                        same either way.
    :return: The digest as a lowercase hex string.
    :raises Exception: If `algo` is not a supported name, or `path` does not exist.
    """
    import hashlib

    if algo == "md5":
        hasher = hashlib.md5()
    elif algo == "sha1":
        hasher = hashlib.sha1()
    else:
        raise Exception(f'Unknown hash algorithm specified: {algo}')
    if not os.path.exists(path):
        raise Exception(f'Path not found: {path}')

    # This is called on model checkpoints, which can be very large. Hashing in fixed-size chunks keeps
    # peak memory at one chunk instead of the whole file.
    chunk_size = buffer_size if buffer_size > 0 else 1024 * 1024
    with open(path, 'rb') as f:
        for data in iter(lambda: f.read(chunk_size), b''):
            hasher.update(data)
    return hasher.hexdigest()
2023-02-24 23:10:04 +00:00
def check_for_kill_signal():
    """
    Abort the current job if a stop has been requested.

    Reads the module-level STOP_SIGNAL flag. When it is set, the flag is cleared first (so the next job can
    run) and then an Exception is raised to unwind the caller.
    """
    global STOP_SIGNAL
    if not STOP_SIGNAL:
        return
    STOP_SIGNAL = False
    raise Exception("Kill signal detected")
2023-02-24 23:10:04 +00:00
def tqdm_override(arr, verbose=False, progress=None, desc=None):
    """
    Wrap `arr` in a progress-reporting iterator, after checking for a pending kill signal.

    :param arr: Iterable to wrap.
    :param verbose: Print `desc` (when given) and show the console tqdm bar; the bar is disabled otherwise.
    :param progress: Optional object exposing a `tqdm(...)` method. When given, it is used instead of the
                     console bar (with track_tqdm=True), and its optional `msg_prefix` attribute is
                     prepended to the description.
    :param desc: Description text.
    """
    check_for_kill_signal()

    if verbose and desc is not None:
        print(desc)

    if progress is not None:
        label = f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc
        return progress.tqdm(arr, desc=label, track_tqdm=True)
    return tqdm(arr, disable=not verbose)
2023-02-03 04:20:01 +00:00
2022-04-26 16:24:03 +00:00
def download_models(specific_models=None):
    """
    Call to download all the models that Tortoise uses.

    Every entry of MODELS whose file is not already present in MODELS_DIR is fetched into that directory.

    :param specific_models: Optional collection of model file names (keys of MODELS). When given, only those
                            names are considered; None means all models.

    Each file is first downloaded to "<name>.part" and renamed to its final name only after the transfer
    completes. An interrupted download therefore cannot leave a truncated checkpoint at the final path,
    which later calls would skip as already downloaded.
    """
    global pbar
    os.makedirs(MODELS_DIR, exist_ok=True)

    def show_progress(block_num, block_size, total_size):
        # urlretrieve reporthook. urllib reports total_size as -1 when the server sends no size.
        # A bar without a valid maximum cannot be drawn, so progress display is skipped in that case.
        global pbar
        if total_size <= 0:
            return
        if pbar is None:
            pbar = progressbar.ProgressBar(maxval=total_size)
            pbar.start()
        downloaded = block_num * block_size
        if downloaded < total_size:
            pbar.update(downloaded)
        else:
            pbar.finish()
            pbar = None

    for model_name, url in MODELS.items():
        if specific_models is not None and model_name not in specific_models:
            continue
        model_path = os.path.join(MODELS_DIR, model_name)
        if os.path.exists(model_path):
            continue
        print(f'Downloading {model_name} from {url}...')
        # An empty ProxyHandler mapping disables proxy auto-detection. The opener is installed globally
        # so that urlretrieve uses it, and with it the custom User-Agent header.
        proxy = ProxyHandler({})
        opener = build_opener(proxy)
        opener.addheaders = [('User-Agent', 'mrq/AI-Voice-Cloning')]
        install_opener(opener)

        partial_path = f'{model_path}.part'
        try:
            request.urlretrieve(url, partial_path, show_progress)
            os.replace(partial_path, model_path)
        except BaseException:
            # Best effort: discard the incomplete file so the next attempt starts clean, then re-raise.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise
        finally:
            # Never carry a half-finished progress bar over to the next download.
            pbar = None
        print('Done.')
2022-05-25 10:22:50 +00:00
def get_model_path(model_name, models_dir=MODELS_DIR):
    """
    Get path to given model, download it if it doesn't exist.

    :param model_name: A key of MODELS.
    :param models_dir: Directory to look in.
    :return: os.path.join(models_dir, model_name).
    :raises ValueError: If `model_name` is not listed in MODELS.

    A missing file is downloaded only when `models_dir` compares equal to MODELS_DIR. For any other
    directory the path is returned as-is, whether or not the file exists.
    """
    if model_name not in MODELS:
        raise ValueError(f'Model {model_name} not found in available models.')
    model_path = os.path.join(models_dir, model_name)
    should_fetch = models_dir == MODELS_DIR and not os.path.exists(model_path)
    if should_fetch:
        download_models([model_name])
    return model_path
2022-04-01 20:15:17 +00:00
def pad_or_truncate(t, length):
    """
    Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.

    Only the last dimension is affected, and zeros are appended at its end. A tensor that already has the
    requested length is returned unchanged (the very same object).
    """
    current = t.shape[-1]
    if current > length:
        return t[..., :length]
    if current < length:
        return F.pad(t, (0, length - current))
    return t
2022-03-29 19:59:39 +00:00
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
    """
    Helper function to load a GaussianDiffusion instance configured for use as a vocoder.

    :param trained_diffusion_steps: Length of the 'linear' beta schedule, also the step count handed to space_timesteps.
    :param desired_diffusion_steps: Passed to space_timesteps as its single section count.
    :param cond_free: Forwarded as SpacedDiffusion's `conditioning_free`.
    :param cond_free_k: Forwarded as SpacedDiffusion's `conditioning_free_k`.
    :return: The configured SpacedDiffusion (epsilon mean, learned_range variance, mse loss).
    """
    kept_timesteps = space_timesteps(trained_diffusion_steps, [desired_diffusion_steps])
    beta_schedule = get_named_beta_schedule('linear', trained_diffusion_steps)
    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        model_mean_type='epsilon',
        model_var_type='learned_range',
        loss_type='mse',
        betas=beta_schedule,
        conditioning_free=cond_free,
        conditioning_free_k=cond_free_k,
    )
2022-03-29 01:33:31 +00:00
2023-02-07 18:34:29 +00:00
def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050):
    """
    Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.

    :param clip: Waveform tensor whose last dimension is time. The `clip[:, ...]` slice below assumes a
                 2-D (channels, samples) layout; TODO confirm with callers.
    :param cond_length: Number of samples to keep (6 s at the default 22050 Hz). Shorter clips are
                        zero-padded at the end; longer ones are cut to a randomly placed window of this length.
    :param device: Device the resulting spectrogram is moved to.
    :param sampling_rate: Sample rate handed to TorchMelSpectrogram.
    :return: The spectrogram with a leading dimension of size 1 added, on `device`.
    """
    excess = clip.shape[-1] - cond_length
    if excess > 0:
        start = random.randint(0, excess)
        clip = clip[:, start:start + cond_length]
    elif excess < 0:
        clip = F.pad(clip, pad=(0, -excess))
    mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
    return mel_clip.unsqueeze(0).to(device)
2022-03-29 01:33:31 +00:00
2022-04-15 03:50:57 +00:00
def fix_autoregressive_output(codes, stop_token, complain=True):
    """
    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
    trained on and what the autoregressive code generator creates (which has no padding or end).
    This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
    a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
    and copying out the last few codes.

    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.

    :param codes: Tensor of generated codes, expected to be a single 1-D sequence. It is modified in place.
    :param stop_token: Token id marking the end of generation.
    :param complain: Print a warning when no stop token is found.
    :return: `codes` itself (same tensor object), padded as described.
    """
    # Strip off the autoregressive stop token and add padding.
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
        if complain:
            print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
                  "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
                  "try breaking up your input text.")
        return codes

    # Everything from the first stop token onwards (which includes every stop token) becomes padding code 83.
    first_stop = stop_token_indices.min().item()
    codes[first_stop:] = 83
    # The last three positions always receive the 45, 45, 248 ending, regardless of where the stop token was.
    # Sequences shorter than three codes have no room for it and keep only the padding (this used to raise IndexError).
    if codes.shape[0] >= 3:
        codes[-3] = 45
        codes[-2] = 45
        codes[-1] = 248
    return codes
2023-02-07 18:34:29 +00:00
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, progress=None, desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000):
    """
    Uses the specified diffusion model to convert discrete codes into a spectrogram.

    :param diffusion_model: Diffusion decoder; its timestep_independent() output is passed to every sampling step.
    :param diffuser: Diffusion sampler. Its `sampler` attribute is overwritten with `sampler.lower()`.
    :param latents: Autoregressive latents; shape[0] is used as the batch size and shape[1] to size the output.
    :param conditioning_latents: Passed to diffusion_model.timestep_independent alongside `latents`.
    :param temperature: Multiplier applied to the initial Gaussian noise.
    :param verbose: Forwarded to diffuser.sample_loop.
    :param progress: Forwarded to diffuser.sample_loop.
    :param desc: Forwarded to diffuser.sample_loop.
    :param sampler: Sampler name; lower-cased before being set on `diffuser`.
    :param input_sample_rate: Used with output_sample_rate to rescale the output length.
    :param output_sample_rate: Used with input_sample_rate to rescale the output length.
    :return: Denormalized mel spectrogram trimmed to the computed length; moved to the CPU on DirectML.
    """
    with torch.no_grad():
        batch_size = latents.shape[0]
        # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
        output_seq_len = latents.shape[1] * 4 * output_sample_rate // input_sample_rate
        output_shape = (batch_size, 100, output_seq_len)

        embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
        initial_noise = torch.randn(output_shape, device=latents.device) * temperature

        diffuser.sampler = sampler.lower()
        sample_kwargs = {'precomputed_aligned_embeddings': embeddings}
        raw_mel = diffuser.sample_loop(diffusion_model, output_shape, noise=initial_noise, model_kwargs=sample_kwargs,
                                       verbose=verbose, progress=progress, desc=desc)

        mel = denormalize_tacotron_mel(raw_mel)[:, :, :output_seq_len]
        if get_device_name() == "dml":
            mel = mel.cpu()
        return mel
2022-03-29 01:33:31 +00:00
2022-04-26 16:24:03 +00:00
def classify_audio_clip(clip):
    """
    Returns Tortoise's classifier's estimate of whether the given clip came from Tortoise.

    :param clip: torch tensor containing audio waveform data (get it from load_audio).
    :return: 0-d tensor holding the softmax probability at index [0][0] of the classifier output (a probability,
             not a bool as previously documented). Index 0 is presumably the "generated by Tortoise" class;
             verify against the classifier's training labels.
    """
    classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
                                                    resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
                                                    dropout=0, kernel_size=5, distribute_zero_label=False)
    classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
    # Inference only: put the module in eval mode and skip autograd bookkeeping.
    classifier.eval()
    clip = clip.cpu().unsqueeze(0)
    with torch.no_grad():
        results = F.softmax(classifier(clip), dim=-1)
    return results[0][0]
2022-03-29 01:33:31 +00:00
class TextToSpeech :
2022-04-22 17:34:05 +00:00
"""
Main entry point into Tortoise .
"""
2022-05-02 20:57:29 +00:00
2023-03-03 06:30:58 +00:00
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, use_bigvgan=True):
    """
    Constructor
    :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
                                      GPU OOM errors. Larger numbers generates slightly faster. None or 0 uses
                                      get_device_batch_size().
    :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
                       models, otherwise use the defaults.
    :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
                             (but are still rendered by the model). This can be used for prompt engineering.
                             Default is true.
    :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
    :param minor_optimizations: When true, keep the models resident on `device` (preloaded_tensors) and enable the
                                GPT-2 KV cache (use_kv_cache). KV caching is always disabled on DirectML.
    :param input_sample_rate: Sample rate of the reference audio given to get_conditioning_latents().
    :param output_sample_rate: Rate that reference audio is resampled to for the diffusion conditioning.
    :param autoregressive_model_path: Optional custom autoregressive checkpoint. Falls back to the stock
                                      autoregressive.pth when None or when the path does not exist.
    :param use_bigvgan: Use the BigVGAN vocoder instead of UnivNet.
    """
    if device is None:
        device = get_device(verbose=True)

    self.input_sample_rate = input_sample_rate
    self.output_sample_rate = output_sample_rate

    self.minor_optimizations = minor_optimizations

    # for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
    self.preloaded_tensors = minor_optimizations
    self.use_kv_cache = minor_optimizations
    # KV caching does not work with DirectML; only announce the downgrade when it was actually requested.
    if self.use_kv_cache and get_device_name() == "dml":
        print("KV caching requested but not supported with the DirectML backend, disabling...")
        self.use_kv_cache = False

    self.models_dir = models_dir
    self.autoregressive_batch_size = get_device_batch_size() if autoregressive_batch_size is None or autoregressive_batch_size == 0 else autoregressive_batch_size
    self.enable_redaction = enable_redaction
    self.device = device
    if self.enable_redaction:
        self.aligner = Wav2VecAlignment(device='cpu' if get_device_name() == "dml" else self.device)

    self.tokenizer = VoiceBpeTokenizer()
    self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
    self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)

    # Every model below is built with .cpu() and only moved to self.device at the end. Loading the
    # checkpoints onto the CPU as well, as was already done for the vocoder, means GPU-saved checkpoints
    # also load on CPU-only / DirectML hosts.
    cpu = torch.device('cpu')

    if os.path.exists(f'{models_dir}/autoregressive.ptt'):
        # Assume this is a traced directory.
        self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
        self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
    else:
        self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                           model_dim=1024,
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
        self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path, map_location=cpu))
        self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)

        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
        self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir), map_location=cpu))

    self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
                     text_seq_len=350, text_heads=12,
                     num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
                     use_xformers=True).cpu().eval()
    self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir), map_location=cpu))
    self.cvvp = None  # CVVP model is only loaded if used.

    if use_bigvgan:
        # credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
        self.vocoder = BigVGAN().cpu()
        self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=cpu)['generator'])
    else:
        self.vocoder = UnivNetGenerator().cpu()
        self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=cpu)['model_g'])
    self.vocoder.eval(inference=True)

    # Random latent generators (RLGs) are loaded lazily.
    self.rlg_auto = None
    self.rlg_diffusion = None

    if self.preloaded_tensors:
        self.autoregressive = self.autoregressive.to(self.device)
        self.diffusion = self.diffusion.to(self.device)
        self.clvp = self.clvp.to(self.device)
        self.vocoder = self.vocoder.to(self.device)
2023-02-18 14:08:45 +00:00
def load_autoregressive_model(self, autoregressive_model_path):
    """
    Replace the autoregressive model with the weights from another checkpoint.

    :param autoregressive_model_path: Checkpoint to load. When falsy or nonexistent, the stock
                                      autoregressive.pth from self.models_dir is used instead.
    :return: True if the resolved checkpoint path differs from the previously loaded one.
    """
    previous_path = self.autoregressive_model_path
    self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
    self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)

    # Release this object's reference to the old model before building the replacement.
    del self.autoregressive
    self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                       model_dim=1024,
                                       heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                       train_solo_embeddings=False).cpu().eval()
    # The fresh model lives on the CPU, so load the checkpoint there too. This works even if it was saved from a GPU.
    self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path, map_location=torch.device('cpu')))
    self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
    if self.preloaded_tensors:
        self.autoregressive = self.autoregressive.to(self.device)

    return previous_path != self.autoregressive_model_path
2022-05-25 10:22:50 +00:00
def load_cvvp(self):
    """Load CVVP model (built on the CPU, moved to self.device when preloaded_tensors is set)."""
    self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
                     speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
    # Load onto the CPU to match the freshly built model; avoids failures with GPU-saved checkpoints on CPU-only hosts.
    self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir), map_location=torch.device('cpu')))

    if self.preloaded_tensors:
        self.cvvp = self.cvvp.to(self.device)
2022-05-25 10:22:50 +00:00
2023-02-15 05:01:40 +00:00
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
    """
    Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
    These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
    properties.
    :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
                          A single tensor is also accepted and treated as a one-element list.
    :param return_mels: Also return the stacked autoregressive and diffusion conditioning mels.
    :param verbose: Forwarded to tqdm_override for the per-chunk progress display.
    :param progress: Forwarded to tqdm_override for the per-chunk progress display.
    :param slices: Number of pieces the concatenated, resampled reference audio is split into for the diffusion
                   conditioning; 0 is treated as 1.
    :param max_chunk_size: When set (and slices != 0) and the concatenated audio is longer than this, `slices` is
                           replaced by the smallest count for which int(total_length / count) <= max_chunk_size.
    :param force_cpu: Compute on the CPU instead of self.device (always forced on DirectML).
    """
    with torch.no_grad():
        # DirectML lacks some of the operations needed here, so conditioning is computed on the CPU there.
        if get_device_name() == "dml":
            force_cpu = True
        device = torch.device('cpu') if force_cpu else self.device

        def park(model):
            # Return a borrowed model to its resting place: self.device when preloaded, otherwise the CPU.
            return model.to(self.device) if self.preloaded_tensors else model.cpu()

        if not isinstance(voice_samples, list):
            voice_samples = [voice_samples]
        voice_samples = [v.to(device) for v in voice_samples]

        resampler = torchaudio.transforms.Resample(
            self.input_sample_rate,
            self.output_sample_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )

        auto_cond_list = []
        resampled = []
        for clip in voice_samples:
            auto_cond_list.append(format_conditioning(clip, device=device, sampling_rate=self.input_sample_rate))
            # Resampling is done on the CPU; simpler than getting it to run on the GPU.
            resampled.append(resampler(clip.cpu()).to(device))
        auto_conds = torch.stack(auto_cond_list, dim=1)

        self.autoregressive = self.autoregressive.to(device)
        auto_latent = self.autoregressive.get_conditioning(auto_conds)
        self.autoregressive = park(self.autoregressive)

        concat = torch.cat(resampled, dim=-1)
        total_len = concat.shape[-1]

        if slices == 0:
            slices = 1
        elif max_chunk_size is not None and total_len > max_chunk_size:
            slices = 1
            while int(total_len / slices) > max_chunk_size:
                slices += 1

        chunks = torch.chunk(concat, slices, dim=1)
        # Every chunk is padded/truncated to the first chunk's length, since the mels must be stackable.
        chunk_size = chunks[0].shape[-1]

        diffusion_cond_list = []
        for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
            check_for_kill_signal()
            chunk = pad_or_truncate(chunk, chunk_size)
            diffusion_cond_list.append(wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device))
        diffusion_conds = torch.stack(diffusion_cond_list, dim=1)

        self.diffusion = self.diffusion.to(device)
        diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
        self.diffusion = park(self.diffusion)

    if return_mels:
        return auto_latent, diffusion_latent, auto_conds, diffusion_conds
    return auto_latent, diffusion_latent
2022-05-02 21:40:03 +00:00
def get_random_conditioning_latents(self):
    """
    Produce an (autoregressive, diffusion) pair of conditioning latents from the random latent generators.

    The RLG models are built and loaded from CPU-mapped checkpoints on first use and reused afterwards.
    """
    # Lazy-load the RLG models.
    if self.rlg_auto is None:
        cpu = torch.device('cpu')
        self.rlg_auto = RandomLatentConverter(1024).eval()
        self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=cpu))
        self.rlg_diffusion = RandomLatentConverter(2048).eval()
        self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=cpu))

    with torch.no_grad():
        auto_latent = self.rlg_auto(torch.tensor([0.0]))
        diffusion_latent = self.rlg_diffusion(torch.tensor([0.0]))
    return auto_latent, diffusion_latent
def tts_with_preset(self, text, preset='fast', **kwargs):
    """
    Calls TTS with one of a set of preset generation parameters. Options:
        'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
        'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
        'standard': Very good quality. This is generally about as good as you are going to get.
        'high_quality': Use if you want the absolute best. This is not really worth the compute, though.

    Precedence is shared defaults < preset values < explicit keyword arguments; the merged settings are passed
    to self.tts(). An unknown preset name raises KeyError.
    """
    presets = {
        'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
        'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
        'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
        'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
    }
    # Use generally found best tuning knobs for generation.
    defaults = {
        'temperature': .8,
        'length_penalty': 1.0,
        'repetition_penalty': 2.0,
        'top_p': .8,
        'cond_free_k': 2.0,
        'diffusion_temperature': 1.0,
    }
    settings = {**defaults, **presets[preset], **kwargs}
    return self.tts(text, **settings)
2022-05-01 23:25:18 +00:00
2022-05-17 18:11:18 +00:00
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
        return_deterministic_state=False,
        # autoregressive generation parameters follow
        num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
        sample_batch_size=None,
        # CVVP parameters follow
        cvvp_amount=.0,
        # diffusion generation parameters follow
        diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
        diffusion_sampler="P",
        breathing_room=8,
        half_p=False,
        progress=None,
        **hf_generate_kwargs):
    """
    Produces an audio clip of the given text being spoken with the given reference voice.

    :param text: Text to be spoken.
    :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
    :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
                                 can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
                                 A 4-tuple (auto_latent, diffusion_latent, auto_conds, <ignored>) is also accepted;
                                 only that form (or voice_samples) provides the mels that CVVP scoring needs.
                                 Conditioning latents can be retrieved via get_conditioning_latents().
                                 If neither this nor voice_samples is given, random conditioning latents are used.
    :param k: The number of returned clips. The most likely (as determined by Tortoise's CLVP model) clips are returned.
    :param verbose: Whether or not to print log messages / progress bars indicating the progress of creating a clip. Default=true.
    :param use_deterministic_seed: Seed passed to self.deterministic_state(); None means a time-based seed.
    :param return_deterministic_state: If True, also return (seed, text, voice_samples, conditioning_latents).
    ~~AUTOREGRESSIVE KNOBS~~
    :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
           As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
           Samples are generated in whole batches of the autoregressive batch size, so any remainder is dropped.
           If fewer samples than one batch are requested, a single batch of exactly that many (at least 1) is generated.
    :param temperature: The softmax temperature of the autoregressive model.
    :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs.
    :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence
                               of long silences or "uhhhhhhs", etc.
    :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
    :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
    :param sample_batch_size: Number of sequences requested per autoregressive call. None or 0 means get_device_batch_size().
    ~~CLVP-CVVP KNOBS~~
    :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
                        [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
                        CVVP is only applied when reference mels are available (voice_samples or a 4-tuple of latents).
    ~~DIFFUSION KNOBS~~
    :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
                                 the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
                                 however.
    :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
                      each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output
                      of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and
                      dramatically improves realism.
    :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
                        As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
                        Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k
    :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
                                  are the "mean" prediction of the diffusion network and will sound bland and smeared.
    :param diffusion_sampler: Forwarded as `sampler` to do_spectrogram_diffusion().
    :param breathing_room: Once more than this many consecutive "calm" (silence) tokens appear, the diffusion latents
                           are truncated at that point.
    :param half_p: Enables float16 autocast and self.diffusion.enable_fp16. Forced off on the DirectML ("dml") backend.
    ~~OTHER STUFF~~
    :param progress: Progress object forwarded to tqdm_override() / do_spectrogram_diffusion().
    :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer.
                               Extra keyword args fed to this function get forwarded directly to that API (for example
                               typical-sampling options). Documentation
                               here: https://huggingface.co/docs/transformers/internal/generation_utils
    :return: Generated audio clip(s) as torch tensor(s): a single tensor when one clip is produced, otherwise a list of
             k tensors. Sample rate is presumably self.output_sample_rate (24kHz upstream) — confirm.
    """
    if get_device_name() == "dml" and half_p:
        print("Float16 requested but not supported with the DirectML backend, disabling...")
        half_p = False

    self.diffusion.enable_fp16 = half_p
    deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)

    text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
    text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
    assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'

    # auto_conds (reference mels) are only available from voice_samples or a 4-tuple of latents; CVVP needs them.
    auto_conds = None
    if voice_samples is not None:
        auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True, verbose=verbose)
    elif conditioning_latents is not None:
        if len(conditioning_latents) == 2:
            auto_conditioning, diffusion_conditioning = conditioning_latents
        else:
            auto_conditioning, diffusion_conditioning, auto_conds, _ = conditioning_latents
    else:
        auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()

    diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)

    self.autoregressive_batch_size = get_device_batch_size() if sample_batch_size is None or sample_batch_size == 0 else sample_batch_size

    with torch.no_grad():
        samples = []
        # Fewer samples than one batch: shrink the batch so exactly one batch of the requested size runs.
        # (Previously num_batches was computed first and came out as 0, so nothing was generated.)
        if num_autoregressive_samples < self.autoregressive_batch_size:
            self.autoregressive_batch_size = max(1, num_autoregressive_samples)
        num_batches = max(1, num_autoregressive_samples // self.autoregressive_batch_size)
        stop_mel_token = self.autoregressive.stop_mel_token
        calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"

        self.autoregressive = self.autoregressive.to(self.device)
        auto_conditioning = auto_conditioning.to(self.device)
        text_tokens = text_tokens.to(self.device)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
            for _ in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
                check_for_kill_signal()
                codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                             do_sample=True,
                                                             top_p=top_p,
                                                             temperature=temperature,
                                                             num_return_sequences=self.autoregressive_batch_size,
                                                             length_penalty=length_penalty,
                                                             repetition_penalty=repetition_penalty,
                                                             max_generate_length=max_mel_tokens,
                                                             **hf_generate_kwargs)
                # Right-pad every batch to max_mel_tokens with the stop token so the batches can be concatenated.
                padding_needed = max_mel_tokens - codes.shape[1]
                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                samples.append(codes)

        if not self.preloaded_tensors:
            self.autoregressive = self.autoregressive.cpu()
            auto_conditioning = auto_conditioning.cpu()

        clip_results = []

        # CVVP can only contribute when there is a positive weight AND reference mels to compare against.
        use_cvvp = cvvp_amount > 0 and auto_conds is not None

        if auto_conds is not None:
            auto_conds = auto_conds.to(self.device)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
            if not self.minor_optimizations:
                self.autoregressive = self.autoregressive.cpu()
                self.clvp = self.clvp.to(self.device)

            if use_cvvp:
                if self.cvvp is None:
                    self.load_cvvp()
                if not self.minor_optimizations:
                    self.cvvp = self.cvvp.to(self.device)

            desc = "Computing best candidates"
            if verbose:
                if not use_cvvp:
                    desc = "Computing best candidates using CLVP"
                else:
                    desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"

            for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
                check_for_kill_signal()
                # Fixed in place, so `samples` below sees the cleaned-up codes too.
                for i in range(batch.shape[0]):
                    batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)

                # CLVP is needed unless CVVP is used exclusively. The old code skipped it whenever
                # cvvp_amount == 1, even when CVVP could not run, then appended an unbound `clvp`.
                clvp = None
                if not use_cvvp or cvvp_amount != 1:
                    clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)

                if use_cvvp:
                    # Average the CVVP score over every reference clip in auto_conds (dim 1).
                    cvvp_accumulator = 0
                    for cl in range(auto_conds.shape[1]):
                        cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
                    cvvp = cvvp_accumulator / auto_conds.shape[1]
                    if cvvp_amount == 1:
                        clip_results.append(cvvp)
                    else:
                        clip_results.append(cvvp * cvvp_amount + clvp * (1 - cvvp_amount))
                else:
                    clip_results.append(clvp)

            if not self.preloaded_tensors and auto_conds is not None:
                auto_conds = auto_conds.cpu()

            clip_results = torch.cat(clip_results, dim=0)
            samples = torch.cat(samples, dim=0)
            best_results = samples[torch.topk(clip_results, k=k).indices]

        if not self.preloaded_tensors:
            self.clvp = self.clvp.cpu()
            if self.cvvp is not None:
                self.cvvp = self.cvvp.cpu()

        del samples

        if get_device_name() == "dml":
            text_tokens = text_tokens.cpu()
            best_results = best_results.cpu()
            auto_conditioning = auto_conditioning.cpu()
            self.autoregressive = self.autoregressive.cpu()
        else:
            auto_conditioning = auto_conditioning.to(self.device)
            self.autoregressive = self.autoregressive.to(self.device)

        # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
        # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
        # results, but will increase memory usage.
        best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
                                           torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
                                           torch.tensor([best_results.shape[-1] * self.autoregressive.mel_length_compression], device=text_tokens.device),
                                           return_latent=True, clip_inputs=False)

        diffusion_conditioning = diffusion_conditioning.to(self.device)

        if get_device_name() == "dml":
            self.autoregressive = self.autoregressive.to(self.device)
            best_results = best_results.to(self.device)
            best_latents = best_latents.to(self.device)
            self.vocoder = self.vocoder.cpu()
        else:
            if not self.preloaded_tensors:
                self.autoregressive = self.autoregressive.cpu()

            self.diffusion = self.diffusion.to(self.device)
            self.vocoder = self.vocoder.to(self.device)

        del text_tokens
        del auto_conditioning

        wav_candidates = []
        for b in range(best_results.shape[0]):
            codes = best_results[b].unsqueeze(0)
            latents = best_latents[b].unsqueeze(0)

            # Find the first run of more than `breathing_room` "calm" tokens and trim the latents there.
            # (Loop variable renamed from `k` so it no longer shadows the `k` parameter.)
            ctokens = 0
            for token_idx in range(codes.shape[-1]):
                if codes[0, token_idx] == calm_token:
                    ctokens += 1
                else:
                    ctokens = 0
                if ctokens > breathing_room:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                    latents = latents[:, :token_idx]
                    break

            mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
                                           temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler,
                                           input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate)

            wav = self.vocoder.inference(mel)
            wav_candidates.append(wav)

        if not self.preloaded_tensors:
            self.diffusion = self.diffusion.cpu()
            self.vocoder = self.vocoder.cpu()

        def potentially_redact(clip, text):
            # When redaction is enabled, self.aligner.redact() is applied to the clip (squeezed on dim 1, then re-expanded).
            if self.enable_redaction:
                return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1)
            return clip
        wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]

        res = wav_candidates if len(wav_candidates) > 1 else wav_candidates[0]

        gc.collect()

        if return_deterministic_state:
            return res, (deterministic_seed, text, voice_samples, conditioning_latents)
        return res
def deterministic_state(self, seed=None):
    """
    Seed torch's and Python's `random` generators so a generation run can be reproduced.

    :param seed: Integer seed to use. When None, the current time() truncated to an int is used instead.
    :return: The seed that was actually applied, so the caller can replay the run later.
    """
    if seed is None:
        seed = int(time())

    # torch first, then Python's stdlib RNG (same order as before).
    for apply_seed in (torch.manual_seed, random.seed):
        apply_seed(seed)

    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
    # torch.use_deterministic_algorithms(True)
    return seed