optionally disable checkpointing in x_transformers (and make it so with the cond_encoder in tfdpc_v5)

James Betker 2022-07-06 16:55:57 -06:00
parent 48270272e7
commit 28d5b6a80a
2 changed files with 16 additions and 9 deletions

View File

@@ -97,9 +97,9 @@ class ConditioningEncoder(nn.Module):
                 rotary_pos_emb=True,
                 zero_init_branch_output=True,
                 ff_mult=2,
+                do_checkpointing=do_checkpointing
             )
         self.dim = embedding_dim
+        self.do_checkpointing = do_checkpointing
 
     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
@@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  num_heads=8,
                  dropout=0,
                  use_fp16=False,
+                 checkpoint_conditioning=True, # This will need to be false for DDP training. :(
                  # Parameters for regularization.
                  unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
                  ):
@@ -140,7 +141,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
 
         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -287,7 +288,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4)
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)
@@ -406,6 +407,6 @@ def inference_tfdpc5_with_cheater():
         torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)
 
 if __name__ == '__main__':
-    #test_cheater_model()
-    test_conditioning_splitting_logic()
+    test_cheater_model()
+    #test_conditioning_splitting_logic()
     #inference_tfdpc5_with_cheater()

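The inline comment on the new constructor argument ("This will need to be false for DDP training") refers to the familiar clash between reentrant activation checkpointing and DistributedDataParallel's per-parameter gradient hooks: the re-run of the conditioning encoder in the backward pass can mark parameters ready more than once. A rough sketch of how a DDP training script might consume the new flag is below; the import path and the device handling are assumptions, not part of this commit, while the constructor arguments are copied from test_cheater_model() above.

# Sketch only: the module path is assumed; constructor arguments mirror
# test_cheater_model() in this diff.
from torch.nn.parallel import DistributedDataParallel as DDP
from models.audio.music.tfdpc_v5 import TransformerDiffusionWithPointConditioning  # assumed path

def build_model_for_ddp(local_rank):
    model = TransformerDiffusionWithPointConditioning(
        in_channels=256, out_channels=512, model_channels=1024,
        contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
        unconditioned_percentage=.4,
        checkpoint_conditioning=False,  # the flag added in this commit; keep False under DDP
    ).cuda(local_rank)
    return DDP(model, device_ids=[local_rank])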
View File

@@ -774,6 +774,7 @@ class AttentionLayers(nn.Module):
             use_qk_norm_attn=False,
             qk_norm_attn_seq_len=None,
             zero_init_branch_output=False,
+            do_checkpointing=True,
             **kwargs
     ):
         super().__init__()
@@ -786,6 +787,7 @@ class AttentionLayers(nn.Module):
         self.depth = depth
         self.layers = nn.ModuleList([])
         self.causal = causal
+        self.do_checkpointing = do_checkpointing
 
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -977,17 +979,21 @@ class AttentionLayers(nn.Module):
             else:
                 layer_past = None
 
+            def fake_checkpoint(blk, *args):
+                return blk(*args)
+            chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
+
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = chkpt_fn(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                               prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = chkpt_fn(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
                                                   None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = chkpt_fn(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = chkpt_fn(block, x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
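The x_transformers change itself is just a function swap: when do_checkpointing is off, every block is invoked through a passthrough instead of torch.utils.checkpoint. A minimal, self-contained sketch of that toggle on a toy stack of layers follows; ToyLayers and its dimensions are illustrative, not from the library.

# Minimal sketch of the checkpoint-vs-passthrough toggle used above.
# ToyLayers is illustrative only; just the chkpt_fn selection mirrors the diff.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayers(nn.Module):
    def __init__(self, dim=64, depth=4, do_checkpointing=True):
        super().__init__()
        self.do_checkpointing = do_checkpointing
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )

    def forward(self, x):
        # Pick the real checkpoint wrapper or a no-op once, then route every
        # block through it, the same way the AttentionLayers change does.
        def fake_checkpoint(blk, *args):
            return blk(*args)

        chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
        for block in self.layers:
            x = chkpt_fn(block, x)
        return x


x = torch.randn(2, 64, requires_grad=True)
ToyLayers(do_checkpointing=True)(x).sum().backward()   # recomputes activations in backward
ToyLayers(do_checkpointing=False)(x).sum().backward()  # plain forward; avoids the DDP hook issue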