From 28d5b6a80aecdcbfd110066e8e2f3dd599f826b3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 6 Jul 2022 16:55:57 -0600
Subject: [PATCH] optionally disable checkpointing in x_transformers (and make it so with the cond_encoder in tfdpc_v5)

---
 codes/models/audio/music/tfdpc_v5.py      | 11 ++++++-----
 codes/models/lucidrains/x_transformers.py | 14 ++++++++++----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index e1df8de0..38f2a9a4 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -97,9 +97,9 @@ class ConditioningEncoder(nn.Module):
                 rotary_pos_emb=True,
                 zero_init_branch_output=True,
                 ff_mult=2,
+                do_checkpointing=do_checkpointing
         )
         self.dim = embedding_dim
-        self.do_checkpointing = do_checkpointing
 
     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
@@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  num_heads=8,
                  dropout=0,
                  use_fp16=False,
+                 checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.
                  unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
                  ):
@@ -140,7 +141,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
 
         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -287,7 +288,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4)
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)
@@ -406,6 +407,6 @@ def inference_tfdpc5_with_cheater():
         torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)
 
 if __name__ == '__main__':
-    #test_cheater_model()
-    test_conditioning_splitting_logic()
+    test_cheater_model()
+    #test_conditioning_splitting_logic()
     #inference_tfdpc5_with_cheater()
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 7056a485..02ef7d59 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -774,6 +774,7 @@ class AttentionLayers(nn.Module):
             use_qk_norm_attn=False,
             qk_norm_attn_seq_len=None,
             zero_init_branch_output=False,
+            do_checkpointing=True,
             **kwargs
     ):
         super().__init__()
@@ -786,6 +787,7 @@ class AttentionLayers(nn.Module):
         self.depth = depth
         self.layers = nn.ModuleList([])
         self.causal = causal
+        self.do_checkpointing = do_checkpointing
 
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -977,17 +979,21 @@ class AttentionLayers(nn.Module):
             else:
                 layer_past = None
 
+            def fake_checkpoint(blk, *args):
+                return blk(*args)
+            chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
+
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = chkpt_fn(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                               prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = chkpt_fn(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
                                                   None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = chkpt_fn(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = chkpt_fn(block, x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
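Usage sketch (not part of the applied diff): a minimal example of how the new flags would be exercised, based only on the signatures and the test call visible in the hunks above. The import paths assume `codes/` is the import root, and the Encoder arguments other than do_checkpointing are illustrative placeholders.

    # Hypothetical usage of the new do_checkpointing / checkpoint_conditioning flags.
    from models.lucidrains.x_transformers import Encoder
    from models.audio.music.tfdpc_v5 import TransformerDiffusionWithPointConditioning

    # x_transformers side: pass do_checkpointing=False so each attention/ff block
    # is called directly (via fake_checkpoint) instead of being wrapped in checkpoint().
    enc = Encoder(dim=512, depth=4, heads=8, do_checkpointing=False)

    # tfdpc_v5 side: disable checkpointing only inside the conditioning encoder,
    # which the in-code comment notes is required for DDP training. Constructor
    # values are taken from test_cheater_model() in the patch.
    model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)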