Get rid of checkpointing

It isn't needed in inference.
James Betker 2022-06-15 22:09:15 -06:00
parent 29c1d9e561
commit 958c6d2f73
4 changed files with 12 additions and 26 deletions
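As the commit message notes, torch.utils.checkpoint only pays off when autograd needs to rebuild activations during a backward pass; under torch.no_grad() there is no backward pass, so calling a module directly gives the same output without the recompute machinery. A minimal sketch of that distinction (not part of this commit, using a hypothetical stand-in module):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in module, not taken from the tortoise codebase.
layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
x = torch.randn(8, 256)

# Training-style call: the input requires grad and activations are
# recomputed during backward to save memory.
y_train = checkpoint(layer, x.requires_grad_(True))

# Inference-style call: no autograd graph is built, so checkpointing
# would only add overhead. A plain call is equivalent here.
with torch.no_grad():
    y_infer = layer(x)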

View File

@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):

View File

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]

View File

@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
@@ -44,7 +43,7 @@ class CollapsingTransformer(nn.Module):
     def forward(self, x, **transformer_kwargs):
         h = self.transformer(x, **transformer_kwargs)
         h = h.permute(0, 2, 1)
-        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
+        h = self.pre_combiner(h).permute(0, 2, 1)
         if self.training:
             mask = torch.rand_like(h.float()) > self.mask_percentage
         else:

View File

@@ -1,16 +1,12 @@
-import functools
 import math
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
+from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
 
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
-
-from torch.utils.checkpoint import checkpoint
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
 
 
 DEFAULT_DIM_HEAD = 64
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
                 layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                              prev_attn, layer_mem, layer_past)
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                         prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                                  None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                             None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))