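Attempt to force meaningful codes by adding a surrogate loss (a new `mel_head` predicts from the expanded code embeddings and its output is returned as `mel_pred`)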

This commit is contained in:
James Betker 2022-03-26 08:31:40 -06:00
parent 45804177b8
commit 2a29a71c37
3 changed files with 11 additions and 7 deletions

View File

@ -166,6 +166,7 @@ class DiffusionTtsFlat(nn.Module):
DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads),
) )
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
@ -228,12 +229,14 @@ class DiffusionTtsFlat(nn.Module):
device=code_emb.device) < self.unconditioned_percentage device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
code_emb) code_emb)
expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
mel_pred = self.mel_head(expanded_code_emb)
# Everything after this comment is timestep dependent. # Everything after this comment is timestep dependent.
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
x = self.inp_block(x) x = self.inp_block(x)
x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1) x = torch.cat([x, code_emb], dim=1)
x = self.integrating_conv(x) x = self.integrating_conv(x)
for i, lyr in enumerate(self.layers): for i, lyr in enumerate(self.layers):
# Do layer drop where applicable. Do not drop first and last layers. # Do layer drop where applicable. Do not drop first and last layers.
@ -253,7 +256,7 @@ class DiffusionTtsFlat(nn.Module):
extraneous_addition = extraneous_addition + p.mean() extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0 out = out + extraneous_addition * 0
return out return out, mel_pred
@register_model @register_model
@ -269,7 +272,7 @@ if __name__ == '__main__':
ts = torch.LongTensor([600, 600]) ts = torch.LongTensor([600, 600])
model = DiffusionTtsFlat(512, layer_drop=.3) model = DiffusionTtsFlat(512, layer_drop=.3)
# Test with latent aligned conditioning # Test with latent aligned conditioning
o = model(clip, ts, aligned_latent, cond) #o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning # Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond) o = model(clip, ts, aligned_sequence, cond)

View File

@ -318,7 +318,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args() args = parser.parse_args()
opt = option.parse(args.opt, is_train=True) opt = option.parse(args.opt, is_train=True)

View File

@ -241,12 +241,13 @@ class ConfigurableStep(Module):
# Finally, compute the losses. # Finally, compute the losses.
total_loss = 0 total_loss = 0
for loss_name, loss in self.losses.items(): for loss_name, loss in self.losses.items():
multiplier = 1
# Some losses only activate after a set number of steps. For example, proto-discriminator losses can # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
# be very disruptive to a generator. # be very disruptive to a generator.
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \ if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \ 'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0: 'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
continue multiplier = 0 # Multiply by 0 so gradients still flow and DDP works. Effectively this means the loss is unused.
if loss.is_stateful(): if loss.is_stateful():
l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state) l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
local_state.update(lstate) local_state.update(lstate)
@ -255,7 +256,7 @@ class ConfigurableStep(Module):
l = loss(self.get_network_for_name(self.step_opt['training']), local_state) l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
if not l.isfinite(): if not l.isfinite():
print(f'!!Detected non-finite loss {loss_name}') print(f'!!Detected non-finite loss {loss_name}')
total_loss += l * self.weights[loss_name] total_loss += l * self.weights[loss_name] * multiplier
# Record metrics. # Record metrics.
if isinstance(l, torch.Tensor): if isinstance(l, torch.Tensor):
loss_accumulator.add_loss(loss_name, l) loss_accumulator.add_loss(loss_name, l)