another update

James Betker 2022-05-20 11:25:00 -06:00
parent 28f950b7d3
commit 968660c248

@@ -146,7 +146,6 @@ class FlatDiffusion(nn.Module):
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
         self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
-        self.code_norm = normalization(model_channels)
         self.latent_conditioner = nn.Sequential(
             nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
@@ -154,6 +153,12 @@ class FlatDiffusion(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
+        self.code_converter = nn.Sequential(
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
+        self.code_norm = normalization(model_channels)
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                  AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
@@ -183,7 +188,7 @@ class FlatDiffusion(nn.Module):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
             'layers': list(self.layers.parameters()),
-            'code_converters': list(self.embeddings.parameters()) + list(self.latent_conditioner.parameters()),
+            'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
             'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
             'time_embed': list(self.time_embed.parameters()),
         }
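
Note on the parameter-group fix above: the named lists in `groups` appear to back a per-group gradient-norm report (the keys mirror the module names), so the new code_converter parameters have to be registered here or they would go untracked. A rough, hypothetical sketch of reducing such groups to per-group L2 gradient norms (grad_norms is illustrative, not a helper from this repo):

def grad_norms(groups):
    # groups: dict mapping a group name to a list of nn.Parameter; call after loss.backward().
    norms = {}
    for name, params in groups.items():
        total = 0.0
        for p in params:
            if p.grad is not None:
                total += float(p.grad.detach().pow(2).sum())
        norms[name] = total ** 0.5
    return norms
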
@@ -218,6 +223,7 @@ class FlatDiffusion(nn.Module):
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
+        expanded_code_emb = self.code_converter(expanded_code_emb)
         if not return_code_pred:
             return expanded_code_emb
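
Note on the forward-pass change above: code_emb is presumably a (batch, channels, code-frames) tensor; F.interpolate with mode='nearest' stretches it to the diffusion sequence length, and the newly added code_converter (three AttentionBlocks in this commit) then refines that blocky upsampling before it is used as conditioning. A minimal, self-contained sketch of the same two steps, using a simple convolutional stand-in (SimpleConverter is hypothetical, not the repo's AttentionBlock):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleConverter(nn.Module):
    # Stand-in for the commit's stack of AttentionBlocks: any (B, C, T) -> (B, C, T) module fits here.
    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)

code_emb = torch.randn(2, 512, 100)   # (batch, model_channels, code frames)
expected_seq_len = 400                # length of the diffusion target sequence

# Nearest-neighbor interpolation repeats each code frame out to the target length...
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
# ...and the converter refines the repeated frames into a smoother conditioning signal.
expanded_code_emb = SimpleConverter(512)(expanded_code_emb)
print(expanded_code_emb.shape)  # torch.Size([2, 512, 400])
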
@@ -311,5 +317,5 @@ if __name__ == '__main__':
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
-    o = model(clip, ts, aligned_sequence, cond)
+    o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)