From 2a65c982ca6b02b22a1b892fa77f447be982763d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 21 Mar 2022 15:27:51 -0600
Subject: [PATCH] dont double nest checkpointing

---
 codes/models/audio/tts/diffusion_encoder.py | 24 ++++--------------------
 1 file changed, 4 insertions(+), 20 deletions(-)

diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py
index 629bfaed..c1e04905 100644
--- a/codes/models/audio/tts/diffusion_encoder.py
+++ b/codes/models/audio/tts/diffusion_encoder.py
@@ -30,25 +30,9 @@ class TimeIntegrationBlock(nn.Module):
         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
 class TimestepEmbeddingAttentionLayers(AttentionLayers):
     """
-    Modification of x-transformers.AttentionLayers that performs checkpointing, timestep embeddings and layerdrop.
+    Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop.
     """
     def __init__(
         self,
@@ -176,11 +160,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         # iterate and construct layers
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
             if layer_type == 'a':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs))
+                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
             elif layer_type == 'c':
-                layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs))
+                layer = Attention(dim, heads = heads, **attn_kwargs)
             elif layer_type == 'f':
-                layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs))
+                layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
             else:
                 raise Exception(f'invalid layer type {layer_type}')
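
Note on the change: the patch removes the per-layer CheckpointedLayer wrapper, and the commit subject suggests gradient checkpointing is now applied once at an outer level rather than nested around each Attention/FeedForward block. The sketch below only illustrates that single-level pattern; the OuterCheckpointedStack class and its block list are hypothetical names for illustration and do not come from this repository.

# Hypothetical sketch, assuming checkpointing is applied once per block by
# whatever module owns these layers, instead of also wrapping each inner
# Attention/FeedForward (the "double nesting" this patch removes).
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class OuterCheckpointedStack(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            # One checkpoint call per block: activations inside `block` are
            # recomputed during backward rather than stored, and no inner
            # checkpoint call is nested inside it.
            x = checkpoint(block, x)
        return x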