forked from mrq/DL-Art-School

commit 2a29a71c37 (parent 45804177b8)

attempt to force meaningful codes by adding a surrogate loss
@@ -166,6 +166,7 @@ class DiffusionTtsFlat(nn.Module):
                                                                   DiffusionLayer(model_channels, dropout, num_heads),
                                                                   )
         self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
+        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
         self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
                                     [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
@@ -228,12 +229,14 @@ class DiffusionTtsFlat(nn.Module):
                                                    device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
                                    code_emb)
+        expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
+        mel_pred = self.mel_head(expanded_code_emb)
 
         # Everything after this comment is timestep dependent.
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
+        code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
         x = self.inp_block(x)
-        x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1)
+        x = torch.cat([x, code_emb], dim=1)
         x = self.integrating_conv(x)
         for i, lyr in enumerate(self.layers):
             # Do layer drop where applicable. Do not drop first and last layers.
@@ -253,7 +256,7 @@ class DiffusionTtsFlat(nn.Module):
                     extraneous_addition = extraneous_addition + p.mean()
             out = out + extraneous_addition * 0
 
-        return out
+        return out, mel_pred
 
 
 @register_model
@@ -269,7 +272,7 @@ if __name__ == '__main__':
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3)
     # Test with latent aligned conditioning
-    o = model(clip, ts, aligned_latent, cond)
+    #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
     o = model(clip, ts, aligned_sequence, cond)
 
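The model-side change above is the mechanics of the commit message: the conditioning code embedding is nearest-neighbor interpolated to the signal length as expanded_code_emb, the new mel_head (a 1-d convolution from model_channels back to in_channels) decodes it into a mel prediction, and forward() now returns that mel_pred alongside the diffusion output so a surrogate loss can be attached to it. The diff does not show how mel_pred is actually penalized; the sketch below assumes a plain MSE against the target mel, and the function name and shape comments are illustrative rather than taken from the repository.

import torch
import torch.nn.functional as F

def surrogate_mel_loss(mel_pred: torch.Tensor, mel_target: torch.Tensor) -> torch.Tensor:
    # Hypothetical surrogate term: the conditioning codes must carry enough
    # information for mel_head to reconstruct the target spectrogram directly,
    # before any timestep-dependent processing touches them.
    # mel_pred:   (batch, in_channels, T), i.e. self.mel_head(expanded_code_emb)
    # mel_target: (batch, in_channels, T), the clean mel the diffusion model targets
    if mel_pred.shape[-1] != mel_target.shape[-1]:
        # Guard against off-by-a-frame mismatches from the nearest interpolation.
        mel_pred = F.interpolate(mel_pred, size=mel_target.shape[-1], mode='nearest')
    return F.mse_loss(mel_pred, mel_target)

Weighted into total_loss through the existing ConfigurableStep machinery, a term like this penalizes code embeddings that collapse into something the mel head cannot decode, which is presumably what "force meaningful codes" refers to.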
@@ -318,7 +318,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -241,12 +241,13 @@ class ConfigurableStep(Module):
         # Finally, compute the losses.
         total_loss = 0
         for loss_name, loss in self.losses.items():
+            multiplier = 1
             # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
             # be very disruptive to a generator.
             if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
                'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
-                continue
+                multiplier = 0  # Multiply by 0 so gradients still flow and DDP works. Effectively this means the loss is unused.
             if loss.is_stateful():
                 l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
                 local_state.update(lstate)
@@ -255,7 +256,7 @@ class ConfigurableStep(Module):
                 l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             if not l.isfinite():
                 print(f'!!Detected non-finite loss {loss_name}')
-            total_loss += l * self.weights[loss_name]
+            total_loss += l * self.weights[loss_name] * multiplier
             # Record metrics.
             if isinstance(l, torch.Tensor):
                 loss_accumulator.add_loss(loss_name, l)
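The ConfigurableStep change replaces the old continue with a multiplier. Under DistributedDataParallel, every parameter used in the forward pass is expected to produce a gradient during backward(); silently skipping a loss can leave some parameters without gradients on some ranks and desynchronize or hang the gradient reducer. Multiplying a disabled loss by 0 keeps it in the autograd graph (its gradients flow, but as zeros) while contributing nothing to the update. A minimal standalone illustration of the pattern, with made-up module and loss names:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 1)            # stands in for the network being trained
x = torch.randn(8, 4)
target = torch.zeros(8, 1)

main_loss = F.mse_loss(model(x), target)
aux_loss = model(x).abs().mean()   # e.g. a loss gated off until some step count

multiplier = 0                     # becomes 1 once the 'after'/'before'/'every' conditions pass
total = main_loss + aux_loss * multiplier
total.backward()

# aux_loss still participated in autograd, so every parameter it touches receives
# a (zero) gradient. Under DDP this keeps the per-rank reducers in agreement even
# while the loss is effectively disabled, which a bare `continue` would not.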
|
|