forked from mrq/DL-Art-School
don't double nest checkpointing
parent 723f324eda
commit 2a65c982ca
@@ -30,25 +30,9 @@ class TimeIntegrationBlock(nn.Module):
         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
 class TimestepEmbeddingAttentionLayers(AttentionLayers):
     """
-    Modification of x-transformers.AttentionLayers that performs checkpointing, timestep embeddings and layerdrop.
+    Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop.
     """
     def __init__(
         self,
@@ -176,11 +160,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         # iterate and construct layers
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
             if layer_type == 'a':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs))
+                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs))
+                layer = Attention(dim, heads = heads, **attn_kwargs)
             elif layer_type == 'f':
-                layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs))
+                layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
             else:
                 raise Exception(f'invalid layer type {layer_type}')
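The removed CheckpointedLayer wrapped each Attention/FeedForward sub-layer in torch.utils.checkpoint.checkpoint; with checkpointing presumably applied again at a higher level of the model (as the commit message "don't double nest checkpointing" suggests), every sub-layer ended up inside two nested checkpoint calls. A minimal sketch of that double nesting, using a hypothetical stand-in module rather than code from this repository:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for an Attention or FeedForward sub-layer.
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

def inner(x):
    # What CheckpointedLayer effectively did: checkpoint the sub-layer itself.
    return checkpoint(block, x, use_reentrant=False)

def outer(x):
    # Checkpointing applied one level up (e.g. around the whole transformer block).
    # Nesting it over `inner` means the sub-layer's forward pass is recomputed by
    # both the inner and the outer checkpoint during backward: extra compute
    # without any additional memory savings.
    return checkpoint(inner, x, use_reentrant=False)

x = torch.randn(2, 16, requires_grad=True)
outer(x).sum().backward()                                    # double-nested: what the commit removes
checkpoint(block, x, use_reentrant=False).sum().backward()   # single level: what remains

With the wrapper gone, these layers rely on whatever checkpointing the caller applies, which matches the updated docstring that no longer lists checkpointing among the class's responsibilities.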