From db0c3340ac95f534ce932fc4deb5b46280714e1c Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 1 Mar 2022 11:49:36 -0700
Subject: [PATCH] Implement guidance-free diffusion in eval

And a few other fixes
---
 codes/data/audio/fast_paired_dataset.py       | 24 +++++----
 codes/models/diffusion/gaussian_diffusion.py  |  9 +++-
 codes/models/gpt_voice/unet_diffusion_tts7.py | 52 ++++++++++---------
 codes/train.py                                |  2 +-
 codes/trainer/eval/audio_diffusion_fid.py     |  9 ++--
 codes/trainer/injectors/base_injectors.py     |  4 ++
 6 files changed, 61 insertions(+), 39 deletions(-)

diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index f18a2077..12f8585c 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -263,19 +263,25 @@ if __name__ == '__main__':
     batch_sz = 256
     params = {
         'mode': 'fast_paired_voice_audio',
-        'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
+        'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
+                 'y:/libritts/train-clean-100/transcribed-oco.tsv',
+                 'y:/libritts/train-clean-360/transcribed-oco.tsv',
+                 'y:/clips/books1/transcribed-w2v.tsv',
+                 'y:/clips/books2/transcribed-w2v.tsv',
+                 'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv'],
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
-        'max_wav_length': 255995,
+        'max_wav_length': 163840,
         'max_text_length': 200,
         'sample_rate': 22050,
         'load_conditioning': True,
         'num_conditioning_candidates': 1,
         'conditioning_length': 66000,
-        'use_bpe_tokenizer': True,
-        'load_aligned_codes': True,
-        'produce_ctc_metadata': True,
+        'use_bpe_tokenizer': False,
+        'load_aligned_codes': False,
+        'needs_collate': False,
+        'produce_ctc_metadata': False,
     }
 
     from data import create_dataset, create_dataloader
@@ -294,10 +300,10 @@ if __name__ == '__main__':
         for ib in range(batch_sz):
             #max_pads = max(max_pads, b['ctc_pads'].max())
             #max_repeats = max(max_repeats, b['ctc_repeats'].max())
-            #print(f'{i} {ib} {b["real_text"][ib]}')
-            #save(b, i, ib, 'wav')
+            print(f'{i} {ib} {b["real_text"][ib]}')
+            save(b, i, ib, 'wav')
             pass
-        #if i > 5:
-        #    break
+        if i > 15:
+            break
     print(max_pads, max_repeats)
 
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 21700fc1..3a3d7752 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -125,6 +125,7 @@ class GaussianDiffusion:
         rescale_timesteps=False,
         conditioning_free=False,
         conditioning_free_k=1,
+        ramp_conditioning_free=True,
     ):
         self.model_mean_type = ModelMeanType(model_mean_type)
         self.model_var_type = ModelVarType(model_var_type)
@@ -132,6 +133,7 @@ class GaussianDiffusion:
         self.rescale_timesteps = rescale_timesteps
         self.conditioning_free = conditioning_free
         self.conditioning_free_k = conditioning_free_k
+        self.ramp_conditioning_free = ramp_conditioning_free
 
         # Use float64 for accuracy.
         betas = np.array(betas, dtype=np.float64)
@@ -299,7 +301,12 @@ class GaussianDiffusion:
             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
 
         if self.conditioning_free:
-            model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
+            if self.ramp_conditioning_free:
+                assert t.shape[0] == 1  # This should only be used in inference.
+                cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
+            else:
+                cfk = self.conditioning_free_k
+            model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
 
         def process_xstart(x):
             if denoised_fn is not None:
diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index b54fbe8c..bf05e125 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -408,7 +408,7 @@ class DiffusionTts(nn.Module):
         )
 
 
-    def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None):
+    def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
         """
         Apply the model to an input batch.
 
@@ -419,6 +419,7 @@ class DiffusionTts(nn.Module):
         :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
         :param unaligned_input: A structural input that is not properly aligned with the output of the diffusion model.
                                 Can be combined with a conditioning input to produce more robust conditioning.
+        :param conditioning_free: When set, all conditioning inputs (including tokens, conditioning_input and unaligned_input) will not be considered.
         :return: an [N x C x ...] Tensor of outputs.
         """
         assert conditioning_input is not None
@@ -430,11 +431,6 @@ class DiffusionTts(nn.Module):
             lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
             x = torch.cat([x, lr_input], dim=1)
 
-        if self.enable_unaligned_inputs:
-            assert unaligned_input is not None
-            unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
-            unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
-
         with autocast(x.device.type):
             orig_x_shape = x.shape[-1]
             cm = ceil_multiple(x.shape[-1], 2048)
@@ -447,28 +443,36 @@ class DiffusionTts(nn.Module):
             hs = []
             time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
-            cond_emb = self.contextual_embedder(conditioning_input)
-            if tokens is not None:
-                # Mask out guidance tokens for un-guided diffusion.
-                if self.training and self.nil_guidance_fwd_proportion > 0:
-                    token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
-                    tokens = torch.where(token_mask, self.mask_token_id, tokens)
-                code_emb = self.code_embedding(tokens).permute(0,2,1)
-                cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
-                cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
-                cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
-                code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
+            if conditioning_free:
+                code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
             else:
-                code_emb = cond_emb.unsqueeze(-1)
-            if self.enable_unaligned_inputs:
-                code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
-            else:
-                code_emb = self.conditioning_encoder(code_emb)
+                if self.enable_unaligned_inputs:
+                    assert unaligned_input is not None
+                    unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
+                    unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
+
+                cond_emb = self.contextual_embedder(conditioning_input)
+                if tokens is not None:
+                    # Mask out guidance tokens for un-guided diffusion.
+                    if self.training and self.nil_guidance_fwd_proportion > 0:
+                        token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
+                        tokens = torch.where(token_mask, self.mask_token_id, tokens)
+                    code_emb = self.code_embedding(tokens).permute(0,2,1)
+                    cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
+                    cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
+                    cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
+                    code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
+                else:
+                    code_emb = cond_emb.unsqueeze(-1)
+                if self.enable_unaligned_inputs:
+                    code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
+                else:
+                    code_emb = self.conditioning_encoder(code_emb)
 
             # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
-            if self.unconditioned_percentage > 0:
+            if self.training and self.unconditioned_percentage > 0:
                 unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
-                code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb)
+                code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb)
 
             first = True
             time_emb = time_emb.float()
diff --git a/codes/train.py b/codes/train.py
index e496bc51..58e840d0 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -317,7 +317,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/debug_diffusion_tts7.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index f584c0f1..0be66264 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -41,7 +41,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
         conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
         self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
-                                                       conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled,
+                                                       enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
                                                        conditioning_free_k=conditioning_free_k)
         self.dev = self.env['device']
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
@@ -162,9 +162,10 @@ if __name__ == '__main__':
     from utils.util import load_model_from_config
 
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\5500_generator_ema.pth').cuda()
-    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50,
+                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda()
+    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
+                'conditioning_free': True, 'conditioning_free_k': 2,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
\ No newline at end of file
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 15cd211c..62675b73 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -92,8 +92,12 @@ class GeneratorInjector(Injector):
         if self.grad:
             results = method(*params, **self.args)
         else:
+            was_training = gen.training
+            gen.eval()
             with torch.no_grad():
                 results = method(*params, **self.args)
+            if was_training:
+                gen.train()
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG
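
Note on the guidance math (an illustration, not part of the patch): when conditioning_free is enabled, the patched GaussianDiffusion combines a conditional and an unconditional model output at each sampling step, and with ramp_conditioning_free the guidance strength is annealed across the trajectory. A minimal sketch of that combination rule follows. The function name guided_output and its standalone signature are invented here; the sketch assumes timestep rescaling is off (the patch passes t through self._scale_timesteps first) and that t counts down from num_timesteps toward 0 during sampling.

    import torch

    def guided_output(cond_out: torch.Tensor,
                      uncond_out: torch.Tensor,
                      t: int,
                      num_timesteps: int,
                      conditioning_free_k: float = 2.0,
                      ramp: bool = True) -> torch.Tensor:
        # Guidance strength. With the ramp enabled it is near 0 at the noisiest
        # steps (t close to num_timesteps) and reaches conditioning_free_k at
        # the final steps (t close to 0), mirroring ramp_conditioning_free above.
        cfk = conditioning_free_k * (1.0 - t / num_timesteps) if ramp else conditioning_free_k
        # Classifier-free-style extrapolation away from the unconditional
        # prediction; cfk = 0 returns the conditional output unchanged.
        return (1 + cfk) * cond_out - cfk * uncond_out

With conditioning_free_k = 2, as set in the audio_diffusion_fid.py __main__ block above, the last denoising steps weight the conditional prediction by 3 and subtract twice the unconditional one.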