From 958c6d2f735854ecb054dc7f6431c75485ffa485 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 15 Jun 2022 22:09:15 -0600
Subject: [PATCH] Get rid of checkpointing

It isn't needed in inference.
---
 tortoise/models/arch_util.py     |  2 +-
 tortoise/models/classifier.py    | 11 +----------
 tortoise/models/cvvp.py          |  3 +--
 tortoise/models/xtransformers.py | 22 +++++++++-------------
 4 files changed, 12 insertions(+), 26 deletions(-)

diff --git a/tortoise/models/arch_util.py b/tortoise/models/arch_util.py
index 6a79194..ffce5cf 100644
--- a/tortoise/models/arch_util.py
+++ b/tortoise/models/arch_util.py
@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):
diff --git a/tortoise/models/classifier.py b/tortoise/models/classifier.py
index ce574ea..f92d99e 100644
--- a/tortoise/models/classifier.py
+++ b/tortoise/models/classifier.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]
 
 
diff --git a/tortoise/models/cvvp.py b/tortoise/models/cvvp.py
index 622dd60..544ca47 100644
--- a/tortoise/models/cvvp.py
+++ b/tortoise/models/cvvp.py
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
@@ -44,7 +43,7 @@ class CollapsingTransformer(nn.Module):
     def forward(self, x, **transformer_kwargs):
         h = self.transformer(x, **transformer_kwargs)
         h = h.permute(0, 2, 1)
-        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
+        h = self.pre_combiner(h).permute(0, 2, 1)
         if self.training:
             mask = torch.rand_like(h.float()) > self.mask_percentage
         else:
diff --git a/tortoise/models/xtransformers.py b/tortoise/models/xtransformers.py
index df9ee25..8be2df4 100644
--- a/tortoise/models/xtransformers.py
+++ b/tortoise/models/xtransformers.py
@@ -1,16 +1,12 @@
-import functools
 import math
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
+from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
 
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
-
-from torch.utils.checkpoint import checkpoint
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
 
 DEFAULT_DIM_HEAD = 64
 
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
             layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                               prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
                                                   None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))
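
Why this is safe: torch.utils.checkpoint only pays off during training. It trades memory for compute by discarding intermediate activations in the forward pass and recomputing them in the backward pass; when the model is only run for inference there is no backward pass, so checkpoint(fn, x) and a plain fn(x) return the same values and the wrapper just adds overhead. Below is a minimal standalone sketch of that equivalence (illustrative only; layer and x are stand-ins, not tortoise modules):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # Stand-in for any block that this patch unwraps from checkpoint().
    layer = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).eval()
    x = torch.randn(2, 16)

    with torch.no_grad():  # typical inference setup
        direct = layer(x)               # what the patched code now does
        # What the code did before; may emit a warning that no input requires grad.
        wrapped = checkpoint(layer, x)

    print(torch.allclose(direct, wrapped))  # True: the outputs are identical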