disable checkpointing in conditioning encoder

Author: James Betker
Date:   2022-03-24 11:49:04 -06:00
Commit: a15970dd97
Parent: cc5fc91562


@@ -157,11 +157,11 @@ class DiffusionTtsFlat(nn.Module):
         self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True))
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
             DiffusionLayer(model_channels, dropout, num_heads),
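
For context, a minimal sketch of how a do_checkpoint flag typically gates PyTorch activation checkpointing inside a block like AttentionBlock. Only the class name and the do_checkpoint flag come from the diff above; the block body and gating logic here are illustrative assumptions, not this repository's actual implementation.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlock(nn.Module):
        # Hypothetical stand-in for AttentionBlock: do_checkpoint decides whether
        # the forward pass runs under torch.utils.checkpoint, which discards
        # intermediate activations and recomputes them during backward to save memory.
        def __init__(self, channels, do_checkpoint=True):
            super().__init__()
            self.do_checkpoint = do_checkpoint
            self.body = nn.Sequential(
                nn.Conv1d(channels, channels, 3, padding=1),
                nn.ReLU(),
                nn.Conv1d(channels, channels, 3, padding=1),
            )

        def forward(self, x):
            if self.do_checkpoint and self.training:
                # Trade compute for memory: self.body's activations are
                # recomputed in the backward pass instead of being stored.
                return checkpoint(self.body, x, use_reentrant=False)
            # do_checkpoint=False (as this commit sets for the conditioning
            # encoder): plain forward, activations kept as usual.
            return self.body(x)

The commit message does not state the motivation; one common reason (an assumption here) to disable checkpointing in a conditioning encoder is that the reentrant checkpoint variant complains when none of the checkpointed inputs require grad, e.g. when the encoder consumes raw conditioning audio.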