From a15970dd9785d19f7917c34d1b0e59b6396c47cf Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 24 Mar 2022 11:49:04 -0600 Subject: [PATCH] disable checkpointing in conditioning encoder --- codes/models/audio/tts/unet_diffusion_tts_flat0.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py index 76078164..4ee4cd35 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py @@ -157,11 +157,11 @@ class DiffusionTtsFlat(nn.Module): self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1) self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True)) + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( DiffusionLayer(model_channels, dropout, num_heads),