forked from mrq/DL-Art-School
dont double nest checkpointing
This commit is contained in:
parent 723f324eda
commit 2a65c982ca
@@ -30,25 +30,9 @@ class TimeIntegrationBlock(nn.Module):
         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
 class TimestepEmbeddingAttentionLayers(AttentionLayers):
     """
-    Modification of x-transformers.AttentionLayers that performs checkpointing, timestep embeddings and layerdrop.
+    Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop.
     """
     def __init__(
             self,
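For context, a minimal self-contained sketch (not from this repository; the layer and shapes are made up) of what "double nested" checkpointing means and why the inner wrapper is redundant: if a caller already checkpoints a block, checkpointing each sub-layer again re-runs the same forward computation an extra time during backward without saving additional memory.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(16, 16)            # stand-in for an Attention/FeedForward block

def inner_checkpointed(x):
    # roughly what the removed CheckpointedLayer did for every sub-layer
    return checkpoint(layer, x)

x = torch.randn(2, 16, requires_grad=True)

# double nesting: an outer checkpoint around an already-checkpointed layer,
# so the layer's forward is recomputed twice during backward
y_double = checkpoint(inner_checkpointed, x)

# what the diff moves toward: the sub-layer runs plainly, and checkpointing
# (if wanted at all) is applied at a single level
y_single = checkpoint(layer, x)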
@@ -176,11 +160,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         # iterate and construct layers
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
             if layer_type == 'a':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs))
+                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs))
+                layer = Attention(dim, heads = heads, **attn_kwargs)
             elif layer_type == 'f':
-                layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs))
+                layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
             else:
                 raise Exception(f'invalid layer type {layer_type}')
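With the per-layer wrapper gone, checkpointing is presumably applied once at a higher level; this diff does not show where. A hedged sketch of that single-level pattern, with illustrative names only:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def run_stack(layers, x, checkpointing=True):
    # checkpoint each block exactly once, instead of wrapping blocks that are
    # then checkpointed again by their caller
    for block in layers:
        x = checkpoint(block, x) if checkpointing else block(x)
    return x

stack = nn.ModuleList([nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16)])
out = run_stack(stack, torch.randn(2, 16, requires_grad=True))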