forked from mrq/DL-Art-School
Optionally disable checkpointing in x_transformers (and wire this option through to the cond_encoder in tfdpc_v5)
This commit is contained in:
parent
48270272e7
commit
28d5b6a80a
|
@ -97,9 +97,9 @@ class ConditioningEncoder(nn.Module):
|
|||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=2,
|
||||
do_checkpointing=do_checkpointing
|
||||
)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
h = self.init(x).permute(0,2,1)
|
||||
|
@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
num_heads=8,
|
||||
dropout=0,
|
||||
use_fp16=False,
|
||||
checkpoint_conditioning=True, # This will need to be false for DDP training. :(
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
):
|
||||
|
@ -140,7 +141,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
self.enable_fp16 = use_fp16
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
|
||||
self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim)
|
||||
self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
|
@ -287,7 +288,7 @@ def test_cheater_model():
|
|||
# For music:
|
||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||
unconditioned_percentage=.4)
|
||||
unconditioned_percentage=.4, checkpoint_conditioning=False)
|
||||
print_network(model)
|
||||
for k in range(100):
|
||||
o = model(clip, ts, cl)
|
||||
|
@ -406,6 +407,6 @@ def inference_tfdpc5_with_cheater():
|
|||
torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_cheater_model()
|
||||
test_conditioning_splitting_logic()
|
||||
test_cheater_model()
|
||||
#test_conditioning_splitting_logic()
|
||||
#inference_tfdpc5_with_cheater()
|
||||
|
|
|
@ -774,6 +774,7 @@ class AttentionLayers(nn.Module):
|
|||
use_qk_norm_attn=False,
|
||||
qk_norm_attn_seq_len=None,
|
||||
zero_init_branch_output=False,
|
||||
do_checkpointing=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -786,6 +787,7 @@ class AttentionLayers(nn.Module):
|
|||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
self.causal = causal
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
|
||||
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
|
||||
|
@ -977,17 +979,21 @@ class AttentionLayers(nn.Module):
|
|||
else:
|
||||
layer_past = None
|
||||
|
||||
def fake_checkpoint(blk, *args):
|
||||
return blk(*args)
|
||||
chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
||||
out, inter, k, v = chkpt_fn(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
||||
prev_attn, layer_mem, layer_past)
|
||||
elif layer_type == 'c':
|
||||
if exists(full_context):
|
||||
out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
|
||||
out, inter, k, v = chkpt_fn(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
|
||||
None, prev_attn, None, layer_past)
|
||||
else:
|
||||
out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
|
||||
out, inter, k, v = chkpt_fn(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
|
||||
elif layer_type == 'f':
|
||||
out = checkpoint(block, x)
|
||||
out = chkpt_fn(block, x)
|
||||
|
||||
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
|
||||
present_key_values.append((k.detach(), v.detach()))
|
||||
|
|
Loading…
Reference in New Issue
Block a user