Update API to have more expressive interface for controlling various generation knobs

- Also adds typical decoder support; unfortunately this does not work well with the current model.
This commit is contained in:
James Betker 2022-03-29 13:59:39 -06:00
parent f1adc12505
commit 57c2ce7040
4 changed files with 66 additions and 22 deletions

41
api.py
View File

@ -49,13 +49,13 @@ def download_models():
print('Done.')
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free_k=1)
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def load_conditioning(clip, cond_length=132300):
@ -96,7 +96,7 @@ def fix_autoregressive_output(codes, stop_token):
return codes
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
@ -111,11 +111,10 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
if mean:
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
else:
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
return denormalize_tacotron_mel(mel)[:,:,:msl*4]
@ -150,7 +149,12 @@ class TextToSpeech:
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
self.vocoder.eval(inference=True)
def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True):
def tts(self, text, voice_samples, k=1,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95,
typical_sampling=False, typical_mass=.9,
# diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,):
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text = F.pad(text, (0, 1)) # This may not be necessary.
@ -167,7 +171,7 @@ class TextToSpeech:
else:
cond_diffusion = cond_diffusion[:, :88200]
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
with torch.no_grad():
samples = []
@ -175,11 +179,16 @@ class TextToSpeech:
stop_mel_token = self.autoregressive.stop_mel_token
self.autoregressive = self.autoregressive.cuda()
for b in tqdm(range(num_batches)):
codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True,
top_k=50, top_p=.95,
temperature=.9,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=1)
codes = self.autoregressive.inference_speech(conds, text,
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
typical_sampling=typical_sampling,
typical_mass=typical_mass)
padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
@ -203,7 +212,7 @@ class TextToSpeech:
self.vocoder = self.vocoder.cuda()
for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0)
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False)
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu())
self.diffusion = self.diffusion.cpu()

View File

@ -7,7 +7,7 @@ from utils.audio import load_audio
if __name__ == '__main__':
fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline'
outpath = 'D:\\tmp\\tortoise-tts-eval\\redo_outlier'
outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
os.makedirs(outpath, exist_ok=True)
@ -24,7 +24,8 @@ if __name__ == '__main__':
path = os.path.join(os.path.dirname(fname), line[1])
cond_audio = load_audio(path, 22050)
torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True)
sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1, diffusion_iterations=200, cond_free=False,
top_k=None, top_p=.95, typical_sampling=False, temperature=.7, length_penalty=.5, repetition_penalty=1)
down = torchaudio.functional.resample(sample, 24000, 22050)
fout_path = os.path.join(outpath, os.path.basename(line[1]))
torchaudio.save(fout_path, down.squeeze(0), 22050)

View File

@ -3,11 +3,11 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.arch_util import AttentionBlock
from utils.typical_sampling import TypicalLogitsWarper
def null_position_embeddings(range, dim):
@ -497,7 +497,7 @@ class UnifiedVoice(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
@ -530,8 +530,9 @@ class UnifiedVoice(nn.Module):
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=seq_length, **hf_generate_kwargs)
max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]

33
utils/typical_sampling.py Normal file
View File

@ -0,0 +1,33 @@
import torch
from transformers import LogitsWarper
class TypicalLogitsWarper(LogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores