Compare commits

2 commits: main ... main

Author  SHA1        Message                                        Date
mrq     95f679f4ba  possible fix for when candidates >= samples    2023-10-10 15:30:08 +00:00
mrq     bf3b6c87aa  added compat for coqui's XTTS                  2023-09-16 03:38:21 +00:00

2 changed files with 178 additions and 183 deletions

tortoise/api.py

@@ -51,7 +51,6 @@ MODELS = {
'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json',
}
def hash_file(path, algo="md5", buffer_size=0):
import hashlib
@@ -78,14 +77,12 @@ def hash_file(path, algo="md5", buffer_size=0):
return "{0}".format(hash.hexdigest())
def check_for_kill_signal():
global STOP_SIGNAL
if STOP_SIGNAL:
STOP_SIGNAL = False
raise Exception("Kill signal detected")
def download_models(specific_models=None):
"""
Call to download all the models that Tortoise uses.
@@ -105,7 +102,6 @@ def download_models(specific_models=None):
else:
pbar.finish()
pbar = None
for model_name, url in MODELS.items():
if specific_models is not None and model_name not in specific_models:
continue
@@ -146,18 +142,14 @@ def pad_or_truncate(t, length):
return t[..., :length]
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True,
cond_free_k=1):
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse',
betas=get_named_beta_schedule('linear', trained_diffusion_steps),
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
@torch.inference_mode()
def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050):
"""
@@ -173,7 +165,6 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=2
mel_clip = mel_clip.unsqueeze(0)
return migrate_to_device(mel_clip, device)
def fix_autoregressive_output(codes, stop_token, complain=True):
"""
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
@@ -203,19 +194,15 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
return codes
@torch.inference_mode()
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True,
desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000):
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000):
"""
Uses the specified diffusion model to convert discrete codes into a spectrogram.
"""
with torch.no_grad():
output_seq_len = latents.shape[
1] * 4 * output_sample_rate // input_sample_rate # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_seq_len = latents.shape[1] * 4 * output_sample_rate // input_sample_rate # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len,
False)
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
noise = torch.randn(output_shape, device=latents.device) * temperature
@@ -243,7 +230,6 @@ def classify_audio_clip(clip):
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
def migrate_to_device( t, device ):
if t is None:
return t
@@ -263,7 +249,6 @@ def migrate_to_device(t, device):
return t
class TextToSpeech:
"""
Main entry point into Tortoise.
@@ -274,7 +259,8 @@ class TextToSpeech:
unsqueeze_sample_batches=False,
input_sample_rate=22050, output_sample_rate=24000,
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
):
# ):
use_deepspeed=False): # Add use_deepspeed parameter
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -295,7 +281,8 @@ class TextToSpeech:
self.output_sample_rate = output_sample_rate
self.minor_optimizations = minor_optimizations
self.unsqueeze_sample_batches = unsqueeze_sample_batches
self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable
print(f'use_deepspeed api_debug {use_deepspeed}')
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
self.preloaded_tensors = minor_optimizations
self.use_kv_cache = minor_optimizations
@@ -328,6 +315,7 @@ class TextToSpeech:
self.load_diffusion_model(diffusion_model_path)
self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
text_seq_len=350, text_heads=12,
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
@@ -350,13 +338,11 @@ class TextToSpeech:
self.loading = False
def load_autoregressive_model(self, autoregressive_model_path):
if hasattr(self, "autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path,
autoregressive_model_path):
def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False):
if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
return
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(
autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
new_hash = hash_file(self.autoregressive_model_path)
if hasattr(self,"autoregressive_model_hash") and self.autoregressive_model_hash == new_hash:
@@ -370,13 +356,40 @@ class TextToSpeech:
if hasattr(self, 'autoregressive'):
del self.autoregressive
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2,
layers=30,
model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
# XTTS requires a different "dimensionality" for its autoregressive model
if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts:
dimensionality = {
"max_mel_tokens": 605,
"max_text_tokens": 402,
"max_prompt_tokens": 70,
"max_conditioning_inputs": 1,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 5023, # -1
"start_text_token": 261,
"stop_text_token": 0,
"number_mel_codes": 8194,
"start_mel_token": 8192,
"stop_mel_token": 8193,
}
else:
dimensionality = {
"max_mel_tokens": 604,
"max_text_tokens": 402,
"max_conditioning_inputs": 2,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 255,
"start_text_token": 255,
"checkpointing": False,
"train_solo_embeddings": False
}
self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval()
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
if self.preloaded_tensors:
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
@@ -389,16 +402,27 @@ class TextToSpeech:
self.loading = True
self.diffusion_model_path = diffusion_model_path if diffusion_model_path and os.path.exists(
diffusion_model_path) else get_model_path('diffusion_decoder.pth', self.models_dir)
self.diffusion_model_path = diffusion_model_path if diffusion_model_path and os.path.exists(diffusion_model_path) else get_model_path('diffusion_decoder.pth', self.models_dir)
self.diffusion_model_hash = hash_file(self.diffusion_model_path)
if hasattr(self, 'diffusion'):
del self.diffusion
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()
# XTTS does not require a different "dimensionality" for its diffusion model
dimensionality = {
"model_channels": 1024,
"num_layers": 10,
"in_channels": 100,
"out_channels": 200,
"in_latent_channels": 1024,
"in_tokens": 8193,
"dropout": 0,
"use_fp16": False,
"num_heads": 16,
"layer_drop": 0,
"unconditioned_percentage": 0
}
self.diffusion = DiffusionTts(**dimensionality)
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir)))
if self.preloaded_tensors:
self.diffusion = migrate_to_device( self.diffusion, self.device )
@@ -438,9 +462,7 @@ class TextToSpeech:
self.vocoder = UnivNetGenerator().cpu()
print(f"Loading vocoder model: {self.vocoder_model_path}")
self.vocoder.load_state_dict(
torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[
vocoder_key])
self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key])
self.vocoder.eval(inference=True)
if self.preloaded_tensors:
@@ -453,8 +475,7 @@ class TextToSpeech:
return
self.loading = True
self.tokenizer_json = tokenizer_json if tokenizer_json else os.path.join(
os.path.dirname(os.path.realpath(__file__)), '../tortoise/data/tokenizer.json')
self.tokenizer_json = tokenizer_json if tokenizer_json else os.path.join(os.path.dirname(os.path.realpath(__file__)), '../tortoise/data/tokenizer.json')
print("Loading tokenizer JSON:", self.tokenizer_json)
if hasattr(self, 'tokenizer'):
@@ -467,8 +488,7 @@ class TextToSpeech:
def load_cvvp(self):
"""Load CVVP model."""
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8,
cond_mask_percentage=0,
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
@@ -476,23 +496,12 @@ class TextToSpeech:
self.cvvp = migrate_to_device( self.cvvp, self.device )
@torch.inference_mode()
def get_conditioning_latents(
self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False,
original_ar=False, original_diffusion=False
):
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
"""
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
properties.
:param force_cpu:
:param max_chunk_size:
:param slices:
:param verbose:
:param return_mels:
:param original_diffusion:
:param original_ar:
:param voice_samples: List of 2 or more ~10 second reference clips,
which should be torch tensors containing 22.05kHz waveform data.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
"""
with torch.no_grad():
@@ -530,8 +539,7 @@ class TextToSpeech:
if original_ar:
samples = [resampler_22K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate,
cond_length=132300))
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
else:
samples = [resampler_22K(sample) for sample in voice_samples]
concat = torch.cat(samples, dim=-1)
@@ -548,30 +556,27 @@ class TextToSpeech:
chunk_size = chunks[0].shape[-1]
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate,
cond_length=chunk_size))
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
if original_diffusion:
samples = [resampler_24K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False,
device=self.device)
cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel)
else:
samples = [resampler_24K(sample) for sample in voice_samples]
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device(chunk, device), do_normalization=False,
device=device)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device )
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = migrate_to_device(self.autoregressive,
self.device if self.preloaded_tensors else 'cpu')
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = migrate_to_device( self.diffusion, device )
@@ -587,11 +592,9 @@ class TextToSpeech:
# Lazy-load the RLG models.
if self.rlg_auto is None:
self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(
torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
@@ -613,8 +616,6 @@ class TextToSpeech:
'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
'narration': {'num_autoregressive_samples': 30, 'diffusion_iterations': 80, "diffusion_sampler": "DDIM"},
'dialogue': {'num_autoregressive_samples': 60, 'diffusion_iterations': 120, "diffusion_sampler": "DDIM"}
}
settings.update(presets[preset])
settings.update(kwargs) # allow overriding of preset settings with kwargs
@@ -624,8 +625,7 @@ class TextToSpeech:
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
return_deterministic_state=False,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8,
max_mel_tokens=500,
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
sample_batch_size=None,
autoregressive_model=None,
diffusion_model=None,
@@ -710,14 +710,11 @@ class TextToSpeech:
text_tokens = migrate_to_device( text_tokens, self.device )
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text_tokens.shape[
-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
auto_conds = None
if voice_samples is not None:
auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples,
return_mels=True,
verbose=True)
auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True, verbose=True)
elif conditioning_latents is not None:
latent_tuple = conditioning_latents
if len(latent_tuple) == 2:
@@ -727,8 +724,7 @@ class TextToSpeech:
else:
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free,
cond_free_k=cond_free_k)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
self.autoregressive_batch_size = get_device_batch_size() if sample_batch_size is None or sample_batch_size == 0 else sample_batch_size
@@ -745,7 +741,7 @@ class TextToSpeech:
text_tokens = migrate_to_device( text_tokens, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
for b in tqdm(range(num_batches), desc="Generating autoregressive samples", disable=not verbose):
for b in tqdm(range(num_batches), desc="Generating autoregressive samples"):
check_for_kill_signal()
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
@@ -793,7 +789,8 @@ class TextToSpeech:
else:
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
for batch in tqdm(samples, desc=desc, disable=not verbose):
for batch in tqdm(samples, desc=desc):
check_for_kill_signal()
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
@@ -804,8 +801,7 @@ class TextToSpeech:
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(
auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
@@ -819,12 +815,16 @@ class TextToSpeech:
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
if k < num_autoregressive_samples:
best_results = samples[torch.topk(clip_results, k=k).indices]
else:
best_results = samples
if not self.preloaded_tensors:
self.clvp = migrate_to_device( self.clvp, 'cpu' )
self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
if get_device_name() == "dml":
text_tokens = migrate_to_device( text_tokens, 'cpu' )
best_results = migrate_to_device( best_results, 'cpu' )
@@ -840,11 +840,8 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
best_results,
torch.tensor([best_results.shape[
-1] * self.autoregressive.mel_length_compression],
device=text_tokens.device),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
@@ -881,11 +878,8 @@ class TextToSpeech:
break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature,
desc="Transforming autoregressive outputs into audio..",
sampler=diffusion_sampler,
input_sample_rate=self.input_sample_rate,
output_sample_rate=self.output_sample_rate)
temperature=diffusion_temperature, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler,
input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav)
@@ -900,7 +894,6 @@ class TextToSpeech:
t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
return clip
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
if len(wav_candidates) > 1:
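
Below is a brief usage sketch, not part of either commit, showing how the two new knobs added to this file might be exercised: the use_deepspeed constructor argument and the is_xtts flag on load_autoregressive_model(). The import path and checkpoint path are illustrative assumptions; with is_xtts left False, the XTTS dimensionality is still picked automatically when the checkpoint hash matches the one hard-coded above.

# Sketch only (not from these commits); module path and checkpoint path are assumptions.
from tortoise.api import TextToSpeech

tts = TextToSpeech(
    minor_optimizations=True,
    use_deepspeed=False,  # new constructor argument, forwarded to post_init_gpt2_config()
)

# Force the XTTS-shaped UnifiedVoice for a checkpoint whose hash is not the one
# hard-coded in load_autoregressive_model(); otherwise the hash check decides.
tts.load_autoregressive_model("./models/xtts_autoregressive.pth", is_xtts=True)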

tortoise/models/autoregressive.py

@@ -283,9 +283,9 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1):
"""
@@ -295,6 +295,7 @@ class UnifiedVoice(nn.Module):
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_prompt_tokens: compat set to 2, 70 for XTTS
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
@@ -311,7 +312,7 @@ class UnifiedVoice(nn.Module):
self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = 0
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
@@ -319,6 +320,7 @@ class UnifiedVoice(nn.Module):
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.max_prompt_tokens = max_prompt_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
@@ -353,7 +355,7 @@ class UnifiedVoice(nn.Module):
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
@@ -494,7 +496,7 @@ class UnifiedVoice(nn.Module):
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
if not hasattr(self, 'inference_model'):
self.post_init_gpt2_config(kv_cache=self.kv_cache)
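
As a small worked illustration (not part of the diff), the seq_length change above derives the GPT-2 context length from max_prompt_tokens instead of a hard-coded + 2. With the stock dimensionality from api.py, which leaves max_prompt_tokens at its new default of 2, the value is unchanged; the XTTS dimensionality widens it:

# Illustration only: values taken from the two dimensionality dicts added in api.py above.
stock = {"max_mel_tokens": 604, "max_text_tokens": 402, "max_prompt_tokens": 2}   # default max_prompt_tokens
xtts  = {"max_mel_tokens": 605, "max_text_tokens": 402, "max_prompt_tokens": 70}

for name, d in (("stock", stock), ("xtts", xtts)):
    seq_length = d["max_mel_tokens"] + d["max_text_tokens"] + d["max_prompt_tokens"]
    print(name, seq_length)  # stock -> 1008 (same as the old "+ 2"), xtts -> 1077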