Get rid of checkpointing

It isn't needed in inference.
James Betker 2022-06-15 22:09:15 -06:00
parent 1aa4e0d4b8
commit e5201bf14e
3 changed files with 11 additions and 24 deletions
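
The rationale in the commit message is easy to verify: torch.utils.checkpoint only pays off during a backward pass, where it trades memory for recomputation of activations. Under torch.no_grad() there is no backward pass, so wrapping a call in checkpoint() returns exactly the same tensors as calling the module directly and only adds overhead. A minimal standalone sketch (not code from this repo):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(16, 16)
    x = torch.randn(4, 16)

    with torch.no_grad():
        direct = layer(x)               # call pattern after this commit
        wrapped = checkpoint(layer, x)  # call pattern before this commit
        # Identical outputs; recent PyTorch versions additionally warn about
        # the use_reentrant default, one more reason to drop the wrapper.
        assert torch.allclose(direct, wrapped)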

View File

@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):
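
For context, the hunk above turns CheckpointedLayer into a plain pass-through wrapper. A sketch of the resulting class, with the constructor assumed from the self.wrap reference (only the forward body comes from the diff):

    import functools
    import torch
    import torch.nn as nn

    class CheckpointedLayer(nn.Module):
        # Binds kwargs with functools.partial and calls the wrapped module
        # directly; no activation checkpointing is involved any more.
        def __init__(self, wrap):
            super().__init__()
            self.wrap = wrap

        def forward(self, x, *args, **kwargs):
            for k, v in kwargs.items():
                # Kept from the checkpointing era: tensors passed as kwargs
                # must not require grad.
                assert not (isinstance(v, torch.Tensor) and v.requires_grad)
            partial = functools.partial(self.wrap, **kwargs)
            return partial(x, *args)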

View File

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]
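
The two hunks above make the same simplification twice: ResBlock drops its do_checkpoint/_forward indirection, and AudioMiniEncoder calls its attention blocks directly. The shape of that refactor, shown on a hypothetical minimal module (not the actual ResBlock):

    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class BeforeBlock(nn.Module):
        def __init__(self, do_checkpoint=True):
            super().__init__()
            self.do_checkpoint = do_checkpoint
            self.body = nn.Sequential(nn.Conv1d(8, 8, 3, padding=1), nn.ReLU())

        def forward(self, x):
            # Training-oriented dispatch: optionally recompute activations
            # during backward to save memory.
            if self.do_checkpoint:
                return checkpoint(self._forward, x)
            return self._forward(x)

        def _forward(self, x):
            return self.body(x)

    class AfterBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.body = nn.Sequential(nn.Conv1d(8, 8, 3, padding=1), nn.ReLU())

        def forward(self, x):
            # Inference-oriented: a single forward path, no indirection.
            return self.body(x)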

View File

@ -1,16 +1,12 @@
import functools
import math import math
import torch from collections import namedtuple
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial from functools import partial
from inspect import isfunction from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce import torch
from einops.layers.torch import Rearrange import torch.nn.functional as F
from einops import rearrange, repeat
from torch.utils.checkpoint import checkpoint from torch import nn, einsum
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
                 layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                              prev_attn, layer_mem, layer_past)
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                         prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                                  None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                             None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
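
With these hunks applied, every sub-block in AttentionLayers is invoked directly and its outputs, including the k/v tensors cached above, come straight back to the caller. What makes this safe is that inference is normally run with autograd disabled, so there are no activations for checkpointing to save in the first place. A hedged usage sketch with a generic model, not this repo's API:

    import torch

    @torch.no_grad()  # or torch.inference_mode() on recent PyTorch
    def generate(model, x):
        # No autograd graph is recorded here, so checkpoint() would have
        # nothing to recompute; calling the blocks directly is equivalent.
        return model(x)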