Get rid of checkpointing

It isn't needed in inference.
James Betker 2022-06-15 22:09:15 -06:00
parent 29c1d9e561
commit 958c6d2f73
4 changed files with 12 additions and 26 deletions
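
Background on the change: torch.utils.checkpoint.checkpoint trades memory for compute during training by discarding a block's intermediate activations on the forward pass and re-running the block when backward() needs them. Inference never runs a backward pass, so the wrapper has nothing to save and only adds overhead. A minimal sketch of the difference, using a hypothetical stand-in block rather than one of the Tortoise models:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for any of the layers un-checkpointed in the diffs below.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Training path: checkpoint() frees the block's intermediate activations after
# the forward pass and re-runs the block during backward() to rebuild them.
y = checkpoint(block, x)
y.sum().backward()

# Inference path: there is no backward pass, so a plain call produces the same
# output and skips the checkpoint bookkeeping entirely.
with torch.no_grad():
    y = block(x)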


@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):


@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]


@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
@@ -44,7 +43,7 @@ class CollapsingTransformer(nn.Module):
     def forward(self, x, **transformer_kwargs):
         h = self.transformer(x, **transformer_kwargs)
         h = h.permute(0, 2, 1)
-        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
+        h = self.pre_combiner(h).permute(0, 2, 1)
         if self.training:
             mask = torch.rand_like(h.float()) > self.mask_percentage
         else:


@@ -1,16 +1,12 @@
-import functools
 import math
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
+from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
-from torch.utils.checkpoint import checkpoint
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
 
 DEFAULT_DIM_HEAD = 64
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
                 layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                          prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
                                              None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))