forked from mrq/DL-Art-School
another update
This commit is contained in:
parent 28f950b7d3
commit 968660c248
@@ -146,7 +146,6 @@ class FlatDiffusion(nn.Module):
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
         self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
-        self.code_norm = normalization(model_channels)
         self.latent_conditioner = nn.Sequential(
             nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
@@ -154,6 +153,12 @@ class FlatDiffusion(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
+        self.code_converter = nn.Sequential(
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
+        self.code_norm = normalization(model_channels)
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                  AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
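Note: the new code_converter is a small stack of attention blocks applied to the code embeddings, and code_norm moves so it now follows the converter. A minimal, self-contained sketch of that shape (nn.MultiheadAttention stands in for the repo's AttentionBlock; model_channels, num_heads, and all tensor sizes are assumptions):

    import torch
    import torch.nn as nn

    class ToyAttentionBlock(nn.Module):
        # Hypothetical stand-in for DL-Art-School's AttentionBlock: residual
        # self-attention over a (batch, channels, time) tensor.
        def __init__(self, channels, num_heads):
            super().__init__()
            self.norm = nn.GroupNorm(32, channels)
            self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

        def forward(self, x):                       # x: (B, C, T)
            h = self.norm(x).permute(0, 2, 1)       # (B, T, C) for MultiheadAttention
            h, _ = self.attn(h, h, h)
            return x + h.permute(0, 2, 1)           # residual, back to (B, C, T)

    model_channels, num_heads = 512, 8              # assumed sizes
    code_converter = nn.Sequential(
        ToyAttentionBlock(model_channels, num_heads),
        ToyAttentionBlock(model_channels, num_heads),
        ToyAttentionBlock(model_channels, num_heads),
    )
    code_norm = nn.GroupNorm(32, model_channels)    # stand-in for normalization()

    emb = torch.randn(2, model_channels, 100)       # (batch, channels, num_tokens)
    out = code_norm(code_converter(emb))            # same shape: (2, 512, 100)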
@@ -183,7 +188,7 @@ class FlatDiffusion(nn.Module):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
             'layers': list(self.layers.parameters()),
-            'code_converters': list(self.embeddings.parameters()) + list(self.latent_conditioner.parameters()),
+            'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
             'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
             'time_embed': list(self.time_embed.parameters()),
         }
@@ -218,6 +223,7 @@ class FlatDiffusion(nn.Module):
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
+        expanded_code_emb = self.code_converter(expanded_code_emb)

         if not return_code_pred:
             return expanded_code_emb
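Note: with this change the converter runs after the nearest-neighbour upsampling, so its attention blocks mix information at the full diffusion sequence length rather than at token resolution. A rough sketch of the assumed data flow (shapes and expected_seq_len are illustrative):

    import torch
    import torch.nn.functional as F

    code_emb = torch.randn(2, 512, 100)             # (batch, channels, num_tokens), assumed
    expected_seq_len = 400                          # target diffusion length (illustrative)

    # Nearest-neighbour interpolation stretches each token embedding across the
    # timesteps it covers; code_converter (sketched above) would then be applied
    # to this upsampled tensor before the early return below.
    expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
    print(expanded_code_emb.shape)                  # torch.Size([2, 512, 400])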
@@ -311,5 +317,5 @@ if __name__ == '__main__':
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
-    o = model(clip, ts, aligned_sequence, cond)
+    o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)
