Get rid of checkpointing

It isn't needed in inference.
James Betker 2022-06-15 22:09:15 -06:00
parent 1aa4e0d4b8
commit e5201bf14e
3 changed files with 11 additions and 24 deletions
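For context (not part of the commit): torch.utils.checkpoint trades memory for compute during training by dropping a block's intermediate activations in the forward pass and recomputing them when backward runs. At inference there is no backward pass, so the wrapper saves nothing and only adds overhead, which is why every checkpointed call below can become a direct call. A minimal sketch of that trade-off with a toy block (nothing here is tortoise-tts code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

toy_block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# Training: checkpoint() discards the block's intermediate activations and
# recomputes them during backward, trading extra compute for lower memory.
x_train = torch.randn(8, 64, requires_grad=True)
loss = checkpoint(toy_block, x_train, use_reentrant=False).sum()
loss.backward()  # the block's forward runs a second time here

# Inference: no backward pass means nothing to recompute, so a direct call
# is equivalent and skips the wrapper entirely.
with torch.inference_mode():
    out = toy_block(torch.randn(8, 64))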

View File

@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)


 class CheckpointedXTransformerEncoder(nn.Module):
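With the checkpoint call removed, CheckpointedLayer.forward reduces to partially applying the keyword arguments and calling the wrapped module directly; the functools.partial (and the assert guarding against grad-requiring kwargs, which only mattered for checkpointing) is now just plumbing. A hedged illustration with a stand-in wrap function, not the repo's code:

import functools

def wrap(x, scale=1.0):  # stand-in for self.wrap
    return x * scale

partial = functools.partial(wrap, scale=2.0)
assert partial(3.0) == wrap(3.0, scale=2.0) == 6.0  # the partial is just a direct call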

View File

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint

 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock

@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)

     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]

View File

@@ -1,16 +1,12 @@
 import functools
 import math
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
-
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
-from torch.utils.checkpoint import checkpoint
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange, repeat
-from torch import nn, einsum

 DEFAULT_DIM_HEAD = 64
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
             layer_past = None

             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                               prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
                                                   None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)

             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
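A closing sanity check (my own sketch, not part of the commit): torch.utils.checkpoint.checkpoint(fn, x) returns the same forward output as fn(x); only the memory/recompute behaviour during backward differs. Swapping the wrapped calls for direct calls therefore cannot change any inference result.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blk = nn.Linear(32, 32)                      # stand-in for one transformer block
h = torch.randn(4, 32, requires_grad=True)   # requires_grad so checkpoint has work to save for backward

# Forward outputs match; only backward-time behaviour differs.
assert torch.allclose(checkpoint(blk, h, use_reentrant=False), blk(h))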