diff --git a/api.py b/api.py
index 28ce9ed..799bd16 100644
--- a/api.py
+++ b/api.py
@@ -49,13 +49,13 @@ def download_models():
         print('Done.')
 
 
-def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
     """
     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
-                           conditioning_free=cond_free, conditioning_free_k=1)
+                           conditioning_free=cond_free, conditioning_free_k=cond_free_k)
 
 
 def load_conditioning(clip, cond_length=132300):
@@ -96,7 +96,7 @@ def fix_autoregressive_output(codes, stop_token):
     return codes
 
 
-def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
     """
     Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
     """
@@ -111,11 +111,10 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
 
         output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
         precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
-        if mean:
-            mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
-                                         model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
-        else:
-            mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+
+        noise = torch.randn(output_shape, device=mel_codes.device) * temperature
+        mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
+                                     model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
         return denormalize_tacotron_mel(mel)[:,:,:msl*4]
 
 
@@ -150,7 +149,12 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
-    def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True):
+    def tts(self, text, voice_samples, k=1,
+            # autoregressive generation parameters follow
+            num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95,
+            typical_sampling=False, typical_mass=.9,
+            # diffusion generation parameters follow
+            diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,):
         text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
         text = F.pad(text, (0, 1))  # This may not be necessary.
 
@@ -167,7 +171,7 @@ class TextToSpeech:
         else:
             cond_diffusion = cond_diffusion[:, :88200]
 
-        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
+        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
 
         with torch.no_grad():
             samples = []
@@ -175,11 +179,16 @@ class TextToSpeech:
             stop_mel_token = self.autoregressive.stop_mel_token
             self.autoregressive = self.autoregressive.cuda()
             for b in tqdm(range(num_batches)):
-                codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True,
-                                                             top_k=50, top_p=.95,
-                                                             temperature=.9,
-                                                             num_return_sequences=self.autoregressive_batch_size,
-                                                             length_penalty=1)
+                codes = self.autoregressive.inference_speech(conds, text,
+                                                             do_sample=True,
+                                                             top_k=top_k,
+                                                             top_p=top_p,
+                                                             temperature=temperature,
+                                                             num_return_sequences=self.autoregressive_batch_size,
+                                                             length_penalty=length_penalty,
+                                                             repetition_penalty=repetition_penalty,
+                                                             typical_sampling=typical_sampling,
+                                                             typical_mass=typical_mass)
                 padding_needed = 250 - codes.shape[1]
                 codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                 samples.append(codes)
@@ -203,7 +212,7 @@ class TextToSpeech:
             self.vocoder = self.vocoder.cuda()
             for b in range(best_results.shape[0]):
                 code = best_results[b].unsqueeze(0)
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
                 wav = self.vocoder.inference(mel)
                 wav_candidates.append(wav.cpu())
             self.diffusion = self.diffusion.cpu()
diff --git a/eval_multiple.py b/eval_multiple.py
index 43e3b4a..30bf31f 100644
--- a/eval_multiple.py
+++ b/eval_multiple.py
@@ -7,7 +7,7 @@ from utils.audio import load_audio
 
 if __name__ == '__main__':
     fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\redo_outlier'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath, exist_ok=True)
@@ -24,7 +24,8 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True)
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1, diffusion_iterations=200, cond_free=False,
+                         top_k=None, top_p=.95, typical_sampling=False, temperature=.7, length_penalty=.5, repetition_penalty=1)
         down = torchaudio.functional.resample(sample, 24000, 22050)
         fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
diff --git a/models/autoregressive.py b/models/autoregressive.py
index c1dea14..6f40ca7 100644
--- a/models/autoregressive.py
+++ b/models/autoregressive.py
@@ -3,11 +3,11 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GPT2Config, GPT2PreTrainedModel
+from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 from models.arch_util import AttentionBlock
-
+from utils.typical_sampling import TypicalLogitsWarper
 
 
 def null_position_embeddings(range, dim):
@@ -497,7 +497,7 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
-    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -530,8 +530,9 @@ class UnifiedVoice(nn.Module):
         fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token
 
+        logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=seq_length, **hf_generate_kwargs)
+                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
 
 
diff --git a/utils/typical_sampling.py b/utils/typical_sampling.py
new file mode 100644
index 0000000..ff6bf48
--- /dev/null
+++ b/utils/typical_sampling.py
@@ -0,0 +1,48 @@
+import torch
+from transformers import LogitsWarper
+
+
+class TypicalLogitsWarper(LogitsWarper):
+    """
+    Logits warper for typical sampling: keeps the tokens whose information content (-log p) is closest to the entropy
+    of the predicted distribution until their cumulative probability reaches `mass`, and filters out all the others.
+    """
+
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        """
+        mass: cumulative probability that the kept tokens should cover.
+        filter_value: score written over every token that is filtered out.
+        min_tokens_to_keep: number of closest-to-entropy tokens per row that are never filtered.
+        """
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Returns a copy of the 2D `scores` with every filtered token set to `filter_value`. `input_ids` is not used.
+        """
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        # last_ind equals the vocabulary size when the cumulative mass never reaches `mass` (mass >= 1 or float
+        # rounding), which is out of bounds for the gather() below, so clamp it to the last valid index.
+        last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Never filter the first min_tokens_to_keep entries of the sorted order
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores