optionally disable checkpointing in x_transformers (and make it so with the cond_encoder in tfdpc_v5)
parent 48270272e7
commit 28d5b6a80a
@@ -97,9 +97,9 @@ class ConditioningEncoder(nn.Module):
                              rotary_pos_emb=True,
                              zero_init_branch_output=True,
                              ff_mult=2,
+                             do_checkpointing=do_checkpointing
                              )
         self.dim = embedding_dim
+        self.do_checkpointing = do_checkpointing

     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
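The keyword arguments above belong to the x_transformers Encoder that ConditioningEncoder wraps; Encoder forwards its kwargs to AttentionLayers, which is where the new do_checkpointing flag is consumed (see the x_transformers hunks below). Purely as a hedged illustration of using the flag directly, with made-up dim/depth/heads values:

    # Hypothetical direct use of the new flag; dim/depth/heads are illustrative
    # values. Encoder comes from this repository's x_transformers copy
    # (import it from wherever that module lives in your checkout).
    enc = Encoder(
        dim=1024,
        depth=8,
        heads=8,
        ff_mult=2,
        rotary_pos_emb=True,
        zero_init_branch_output=True,
        do_checkpointing=False,  # run attention/ff blocks without activation checkpointing
    )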
@@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  num_heads=8,
                  dropout=0,
                  use_fp16=False,
+                 checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.
                  unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
                  ):
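The new checkpoint_conditioning flag exists because of the comment above: reentrant gradient checkpointing tends to clash with DistributedDataParallel, especially when some parameters go unused on a given step (which can happen here when conditioning is dropped by the classifier-free-guidance-style unconditioned_percentage mechanism), typically surfacing as "mark a variable ready only once" style errors. Newer PyTorch releases also offer checkpoint(..., use_reentrant=False) as a gentler alternative, but this commit simply lets the caller turn checkpointing off. A hedged sketch of how DDP training might then be set up (the process-group bootstrapping and argument values are assumptions, not code from this repo):

    # Hedged sketch: construct the model with conditioning-encoder checkpointing
    # disabled, then wrap it in DDP. The single-process 'gloo' group below only
    # exists so the example can run standalone; real training would use torchrun
    # or the trainer's own launcher.
    import os
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = TransformerDiffusionWithPointConditioning(
        in_channels=256, out_channels=512, model_channels=1024,
        contraction_dim=512, num_heads=8, num_layers=15,
        unconditioned_percentage=.1,
        checkpoint_conditioning=False,  # sidestep the checkpointing/DDP interaction noted above
    )
    ddp_model = DDP(model, find_unused_parameters=True)  # conditioning params may be skipped on some steps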
@@ -140,7 +141,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16

         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)

         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -287,7 +288,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4)
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)
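clip, ts and cl are built earlier in test_cheater_model and are not part of this hunk, so their exact shapes are not shown here. Purely as a guess at what a smoke test could feed the model, given in_channels=256 and a 256-channel conditioning encoder:

    # Guessed tensor shapes for a quick smoke test; the real test defines
    # clip/ts/cl above this hunk and may use different sizes.
    import torch
    clip = torch.randn(2, 256, 400)            # (batch, in_channels, sequence)
    ts = torch.randint(0, 4000, (2,)).long()   # diffusion timesteps
    cl = torch.randn(2, 256, 400)              # conditioning input for ConditioningEncoder
    o = model(clip, ts, cl)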
@@ -406,6 +407,6 @@ def inference_tfdpc5_with_cheater():
         torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)

 if __name__ == '__main__':
-    #test_cheater_model()
-    test_conditioning_splitting_logic()
+    test_cheater_model()
+    #test_conditioning_splitting_logic()
     #inference_tfdpc5_with_cheater()
@@ -774,6 +774,7 @@ class AttentionLayers(nn.Module):
             use_qk_norm_attn=False,
             qk_norm_attn_seq_len=None,
             zero_init_branch_output=False,
+            do_checkpointing=True,
             **kwargs
     ):
         super().__init__()
@@ -786,6 +787,7 @@ class AttentionLayers(nn.Module):
         self.depth = depth
         self.layers = nn.ModuleList([])
         self.causal = causal
+        self.do_checkpointing = do_checkpointing

         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -977,17 +979,21 @@ class AttentionLayers(nn.Module):
             else:
                 layer_past = None

+            def fake_checkpoint(blk, *args):
+                return blk(*args)
+            chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
+
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                              prev_attn, layer_mem, layer_past)
+                out, inter, k, v = chkpt_fn(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                            prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                                  None, prev_attn, None, layer_past)
+                    out, inter, k, v = chkpt_fn(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                                None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = chkpt_fn(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = chkpt_fn(block, x)

             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
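The fake_checkpoint shim keeps every call site identical: chkpt_fn always has the checkpoint-style signature fn(block, *args), and only the wrapper changes. A self-contained sketch of the same selection pattern, with stand-in module names rather than this file's classes:

    # Condensed illustration of the chkpt_fn toggle used above; Layers is a
    # stand-in, not a class from x_transformers.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    def fake_checkpoint(blk, *args):
        # Same calling convention as torch.utils.checkpoint.checkpoint, but just
        # runs the block directly (no activation recomputation in backward).
        return blk(*args)

    class Layers(nn.Module):
        def __init__(self, dim, depth, do_checkpointing=True):
            super().__init__()
            self.do_checkpointing = do_checkpointing
            self.blocks = nn.ModuleList(
                nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth))

        def forward(self, x):
            chkpt_fn = checkpoint if self.do_checkpointing else fake_checkpoint
            for blk in self.blocks:
                x = chkpt_fn(blk, x)  # call sites never change, only the wrapper does
            return x

    # e.g. Layers(64, 4, do_checkpointing=False)(torch.randn(2, 64))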