forked from mrq/DL-Art-School
disable checkpointing in conditioning encoder
This commit is contained in:
parent
cc5fc91562
commit
a15970dd97
|
@ -157,11 +157,11 @@ class DiffusionTtsFlat(nn.Module):
|
|||
self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True))
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
|
|
Loading…
Reference in New Issue
Block a user