Change latent_conditioner back

James Betker 2022-04-11 09:00:13 -06:00
parent 03d0b90bda
commit a3622462c1


@@ -152,7 +152,13 @@ class DiffusionTtsFlat(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
         self.code_norm = normalization(model_channels)
-        self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
+        self.latent_conditioner = nn.Sequential(
+            nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                  AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
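Note: the hunk above swaps the single 1x1 Conv1d projection for a wider Conv1d followed by a stack of attention blocks, but the tensor contract is unchanged. A minimal shape sketch, using hypothetical dimensions and omitting the AttentionBlock stack (which is shape-preserving):

import torch
import torch.nn as nn

in_latent_channels, model_channels = 1024, 512   # hypothetical values
x = torch.randn(2, in_latent_channels, 50)       # (batch, channels, sequence)

old_converter = nn.Conv1d(in_latent_channels, model_channels, 1)              # removed latent_converter
new_front_end = nn.Conv1d(in_latent_channels, model_channels, 3, padding=1)   # first layer of latent_conditioner

# Both map (B, in_latent_channels, T) -> (B, model_channels, T), so the
# call sites below need no shape changes.
assert old_converter(x).shape == new_front_end(x).shape == (2, model_channels, 50)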
@@ -179,18 +185,18 @@ class DiffusionTtsFlat(nn.Module):
         )
 
         if freeze_everything_except_autoregressive_inputs:
-            for ap in list(self.latent_converter.parameters()):
-                ap.ALLOWED_IN_FLAT = True
             for p in self.parameters():
-                if not hasattr(p, 'ALLOWED_IN_FLAT'):
-                    p.requires_grad = False
-                    p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+                p.DO_NOT_TRAIN = True
+            for ap in list(self.latent_conditioner.parameters()):
+                ap.requires_grad = True
+                del ap.DO_NOT_TRAIN
 
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
             'layers': list(self.layers.parameters()),
-            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
+            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
             'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
             'time_embed': list(self.time_embed.parameters()),
         }
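Note: the rewritten freeze logic drops the ALLOWED_IN_FLAT marker and instead freezes every parameter, then re-enables only the latent_conditioner. A small sketch of that pattern with a hypothetical toy module standing in for DiffusionTtsFlat:

import torch.nn as nn

class ToyFlat(nn.Module):                       # hypothetical stand-in
    def __init__(self):
        super().__init__()
        self.latent_conditioner = nn.Conv1d(8, 4, 3, padding=1)
        self.code_converter = nn.Linear(4, 4)

m = ToyFlat()
# Freeze everything and tag it for the trainer...
for p in m.parameters():
    p.requires_grad = False
    p.DO_NOT_TRAIN = True
# ...then un-freeze just the latent conditioner.
for ap in m.latent_conditioner.parameters():
    ap.requires_grad = True
    del ap.DO_NOT_TRAIN

trainable = [n for n, p in m.named_parameters() if p.requires_grad]
assert all(n.startswith('latent_conditioner') for n in trainable)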
@@ -211,7 +217,7 @@ class DiffusionTtsFlat(nn.Module):
         cond_emb = conds.mean(dim=-1)
         cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
         if is_latent(aligned_conditioning):
-            code_emb = self.latent_converter(aligned_conditioning)
+            code_emb = self.latent_conditioner(aligned_conditioning)
         else:
             code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
             code_emb = self.code_converter(code_emb)
@@ -254,7 +260,7 @@ class DiffusionTtsFlat(nn.Module):
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
             unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
-            unused_params.extend(list(self.latent_converter.parameters()))
+            unused_params.extend(list(self.latent_conditioner.parameters()))
         else:
             if precomputed_aligned_embeddings is not None:
                 code_emb = precomputed_aligned_embeddings
@@ -263,7 +269,7 @@ class DiffusionTtsFlat(nn.Module):
                 if is_latent(aligned_conditioning):
                     unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
                 else:
-                    unused_params.extend(list(self.latent_converter.parameters()))
+                    unused_params.extend(list(self.latent_conditioner.parameters()))
 
             unused_params.append(self.unconditioned_embedding)
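Note: unused_params tracks whichever conditioning branch was skipped on this forward pass. A common reason for this bookkeeping (assumed here; the consuming code is not part of these hunks) is to fold the skipped parameters into the output at zero weight so DistributedDataParallel does not error on parameters that never receive gradients. A hedged sketch of that pattern, with touch_unused as a hypothetical helper:

import torch

def touch_unused(out: torch.Tensor, unused_params) -> torch.Tensor:
    # Add the unused parameters' means scaled by zero: the output value is
    # unchanged, but every parameter joins the autograd graph, which keeps
    # DDP's gradient-synchronization bookkeeping happy.
    extraneous = sum(p.mean() for p in unused_params)
    return out + extraneous * 0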