From 2a29a71c3791b2a5e2096b1ee13cd0d065a1344f Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 26 Mar 2022 08:31:40 -0600 Subject: [PATCH] attempt to force meaningful codes by adding a surrogate loss --- codes/models/audio/tts/unet_diffusion_tts_flat0.py | 11 +++++++---- codes/train.py | 2 +- codes/trainer/steps.py | 5 +++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py index c56903c0..14f59f00 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py @@ -166,6 +166,7 @@ class DiffusionTtsFlat(nn.Module): DiffusionLayer(model_channels, dropout, num_heads), ) self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) @@ -228,12 +229,14 @@ class DiffusionTtsFlat(nn.Module): device=code_emb.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb) + expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest') + mel_pred = self.mel_head(expanded_code_emb) # Everything after this comment is timestep dependent. time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb) x = self.inp_block(x) - x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1) + x = torch.cat([x, code_emb], dim=1) x = self.integrating_conv(x) for i, lyr in enumerate(self.layers): # Do layer drop where applicable. Do not drop first and last layers. @@ -253,7 +256,7 @@ class DiffusionTtsFlat(nn.Module): extraneous_addition = extraneous_addition + p.mean() out = out + extraneous_addition * 0 - return out + return out, mel_pred @register_model @@ -269,7 +272,7 @@ if __name__ == '__main__': ts = torch.LongTensor([600, 600]) model = DiffusionTtsFlat(512, layer_drop=.3) # Test with latent aligned conditioning - o = model(clip, ts, aligned_latent, cond) + #o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence, cond) diff --git a/codes/train.py b/codes/train.py index ee85b690..80e246dc 100644 --- a/codes/train.py +++ b/codes/train.py @@ -318,7 +318,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 95a9fa95..886c4d09 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -241,12 +241,13 @@ class ConfigurableStep(Module): # Finally, compute the losses. total_loss = 0 for loss_name, loss in self.losses.items(): + multiplier = 1 # Some losses only activate after a set number of steps. For example, proto-discriminator losses can # be very disruptive to a generator. if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \ 'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \ 'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0: - continue + multiplier = 0 # Multiply by 0 so gradients still flow and DDP works. Effectively this means the loss is unused. if loss.is_stateful(): l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state) local_state.update(lstate) @@ -255,7 +256,7 @@ class ConfigurableStep(Module): l = loss(self.get_network_for_name(self.step_opt['training']), local_state) if not l.isfinite(): print(f'!!Detected non-finite loss {loss_name}') - total_loss += l * self.weights[loss_name] + total_loss += l * self.weights[loss_name] * multiplier # Record metrics. if isinstance(l, torch.Tensor): loss_accumulator.add_loss(loss_name, l)