dont double nest checkpointing

James Betker 2022-03-21 15:27:51 -06:00
parent 723f324eda
commit 2a65c982ca


@@ -30,25 +30,9 @@ class TimeIntegrationBlock(nn.Module):
         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
 class TimestepEmbeddingAttentionLayers(AttentionLayers):
     """
-    Modification of x-transformers.AttentionLayers that performs checkpointing, timestep embeddings and layerdrop.
+    Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop.
     """
     def __init__(
         self,
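For reference, the deleted CheckpointedLayer implemented a common pattern for combining activation checkpointing with modules that take keyword arguments: the reentrant checkpoint() replays the wrapped callable only on its positional inputs, so kwargs are bound up front with functools.partial and asserted not to require gradients. Below is a minimal, self-contained sketch of that pattern; the function name is illustrative and not part of this repo.

import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


def checkpointed_call(module: nn.Module, x: torch.Tensor, *args, **kwargs):
    # Sketch of the removed CheckpointedLayer.forward(): bind kwargs with
    # functools.partial, then checkpoint only the positional tensor inputs.
    for v in kwargs.values():
        # checkpoint() re-runs the callable on its positional args during backward;
        # a kwarg that required grad would be invisible to that recomputation.
        assert not (isinstance(v, torch.Tensor) and v.requires_grad)
    fn = functools.partial(module, **kwargs)
    return torch.utils.checkpoint.checkpoint(fn, x, *args)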
@@ -176,11 +160,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         # iterate and construct layers
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
             if layer_type == 'a':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs))
+                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs))
+                layer = Attention(dim, heads = heads, **attn_kwargs)
             elif layer_type == 'f':
-                layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs))
+                layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
             else:
                 raise Exception(f'invalid layer type {layer_type}')
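Per the commit title, the rationale is presumably that these attention and feed-forward layers are already run under activation checkpointing at a higher level, so wrapping each one in CheckpointedLayer would nest checkpoint() calls and redundantly replay the inner recomputation during backward without a corresponding memory benefit. The sketch below shows single-level checkpointing applied once at the call site; CheckpointedSequential is a hypothetical illustration under that assumption, not this repo's actual outer wrapper.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class CheckpointedSequential(nn.Module):
    # Hypothetical example: apply exactly one level of checkpointing per layer at the
    # point where the layers are called, instead of also wrapping each layer itself.
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            if self.training and x.requires_grad:
                # Each layer's activations are recomputed once in backward. If `layer`
                # were itself a CheckpointedLayer, the same work would be replayed again.
                x = torch.utils.checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        return x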