diff --git a/do_tts.py b/do_tts.py index a3587c1..8473fa2 100644 --- a/do_tts.py +++ b/do_tts.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F import torchaudio import progressbar +import ocotillo from models.diffusion_decoder import DiffusionTts from models.autoregressive import UnifiedVoice @@ -17,7 +18,7 @@ from models.text_voice_clip import VoiceCLIP from models.vocoder import UnivNetGenerator from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule -from utils.tokenizer import VoiceBpeTokenizer +from utils.tokenizer import VoiceBpeTokenizer, lev_distance pbar = None def download_models(): @@ -47,13 +48,13 @@ def download_models(): print('Done.') -def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200): +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True): """ Helper function to load a GaussianDiffusion instance configured for use as a vocoder. """ return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), - conditioning_free=True, conditioning_free_k=1) + conditioning_free=cond_free, conditioning_free_k=1) def load_conditioning(path, sample_rate=22050, cond_length=132300): @@ -109,11 +110,12 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_ mel = torch.nn.functional.pad(mel_codes, (0, gap)) output_shape = (mel.shape[0], 100, mel.shape[-1]*4) + precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel) if mean: mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device), - model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel}) + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}) else: - mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel}) + mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}) return denormalize_tacotron_mel(mel)[:,:,:msl*4] @@ -136,9 +138,9 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol') - parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512) - parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16) - parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2) + parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=1024) + parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=32) + parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16) parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/') args = parser.parse_args() @@ -192,7 +194,7 @@ if __name__ == '__main__': return_loss=False)) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) - best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices] + best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices] # Delete the autoregressive and clip models to free up GPU memory del samples, clip @@ -210,12 +212,32 @@ if __name__ == '__main__': vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) vocoder = vocoder.cuda() vocoder.eval(inference=True) - diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100) + initial_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=40, cond_free=False) + final_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=500) print("Performing vocoding..") - # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive. + wav_candidates = [] for b in range(best_results.shape[0]): code = best_results[b].unsqueeze(0) - mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False) + mel = do_spectrogram_diffusion(diffusion, initial_diffuser, code, cond_diffusion, mean=False) wav = vocoder.inference(mel) - torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000) + wav_candidates.append(wav.cpu()) + + # Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable. + transcriber = ocotillo.Transcriber(on_cuda=True) + transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000) + best = 99999999 + for i, transcription in enumerate(transcriptions): + dist = lev_distance(transcription, args.text.lower()) + if dist < best: + best = dist + best_codes = best_results[i].unsqueeze(0) + best_wav = wav_candidates[i] + del transcriber + torchaudio.save(os.path.join(args.output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000) + + # Perform diffusion again with the high-quality diffuser. + mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False) + wav = vocoder.inference(mel) + torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000) + diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py index c946663..7a3bb4d 100644 --- a/models/diffusion_decoder.py +++ b/models/diffusion_decoder.py @@ -486,66 +486,40 @@ class DiffusionTts(nn.Module): aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning - def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. - :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded. - :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate. - :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. - :return: an [N x C x ...] Tensor of outputs. - """ - assert conditioning_input is not None - if self.super_sampling_enabled: - assert lr_input is not None - if self.training and self.super_sampling_max_noising_factor > 0: - noising_factor = random.uniform(0,self.super_sampling_max_noising_factor) - lr_input = torch.randn_like(lr_input) * noising_factor + lr_input - lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') - x = torch.cat([x, lr_input], dim=1) - + def timestep_independent(self, aligned_conditioning, conditioning_input): # Shuffle aligned_latent to BxCxS format if is_latent(aligned_conditioning): aligned_conditioning = aligned_conditioning.permute(0, 2, 1) - # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net. - orig_x_shape = x.shape[-1] - x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning) + with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16): + cond_emb = self.contextual_embedder(conditioning_input) + if len(cond_emb.shape) == 3: # Just take the first element. + cond_emb = cond_emb[:, :, 0] + if is_latent(aligned_conditioning): + code_emb = self.latent_converter(aligned_conditioning) + else: + code_emb = self.code_converter(aligned_conditioning) + cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1]) + code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1)) + return code_emb + + def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False): + assert x.shape[-1] % self.alignment_size == 0 with autocast(x.device.type, enabled=self.enable_fp16): - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - # Note: this block does not need to repeated on inference, since it is not timestep-dependent. if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) else: - cond_emb = self.contextual_embedder(conditioning_input) - if len(cond_emb.shape) == 3: # Just take the first element. - cond_emb = cond_emb[:, :, 0] - if is_latent(aligned_conditioning): - code_emb = self.latent_converter(aligned_conditioning) - else: - code_emb = self.code_converter(aligned_conditioning) - cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1]) - code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1)) - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), - code_emb) + code_emb = precomputed_aligned_embeddings - # Everything after this comment is timestep dependent. + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) first = True time_emb = time_emb.float() h = x + hs = [] for k, module in enumerate(self.input_blocks): if isinstance(module, nn.Conv1d): h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') @@ -565,14 +539,7 @@ class DiffusionTts(nn.Module): h = h.float() out = self.out(h) - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) - for p in params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out[:, :, :orig_x_shape] + return out if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 568575c..1e695a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ inflect progressbar einops unidecode -x-transformers \ No newline at end of file +x-transformers +ocotillo \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/tokenizer.py b/utils/tokenizer.py index 53cd2de..ed7e4cd 100644 --- a/utils/tokenizer.py +++ b/utils/tokenizer.py @@ -148,6 +148,20 @@ def english_cleaners(text): text = text.replace('"', '') return text +def lev_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] class VoiceBpeTokenizer: def __init__(self, vocab_file='data/tokenizer.json'):