forked from mrq/tortoise-tts
Get rid of checkpointing
It isn't needed in inference.
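Background for the change: torch.utils.checkpoint trades compute for memory by discarding intermediate activations during the forward pass and recomputing them during backward. At inference there is no backward pass, so wrapping a call in checkpoint() produces the same outputs while only adding overhead. A minimal sketch of that equivalence with a toy module (illustrative only, not code from this repo):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for any of the blocks touched below (hypothetical).
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).eval()
x = torch.randn(2, 16)

y_direct = block(x)                                 # what the commit switches to
y_ckpt = checkpoint(block, x, use_reentrant=False)  # what it removes (use_reentrant needs a recent PyTorch)

# Outputs match; checkpointing only changes what is stored for the backward pass,
# which is irrelevant when no gradients are needed.
assert torch.allclose(y_direct, y_ckpt)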
This commit is contained in:
parent 29c1d9e561
commit 958c6d2f73
@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):
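With the checkpoint call gone, CheckpointedLayer is reduced to a plain pass-through around the module it wraps. A sketch of the resulting behavior, paraphrased from the context lines above (hypothetical class name, not the file's verbatim contents):

import functools
import torch.nn as nn

class PassthroughLayer(nn.Module):
    # Roughly what CheckpointedLayer.forward amounts to after this commit.
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        # The requires_grad assert above only mattered while checkpointing was in use;
        # here the kwargs are simply bound and the wrapped module is called directly.
        partial = functools.partial(self.wrap, **kwargs)
        return partial(x, *args)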
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]
 
 
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
@@ -44,7 +43,7 @@ class CollapsingTransformer(nn.Module):
     def forward(self, x, **transformer_kwargs):
         h = self.transformer(x, **transformer_kwargs)
         h = h.permute(0, 2, 1)
-        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
+        h = self.pre_combiner(h).permute(0, 2, 1)
         if self.training:
             mask = torch.rand_like(h.float()) > self.mask_percentage
         else:
@@ -1,16 +1,12 @@
-import functools
 import math
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
 from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
 
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
 
-from torch.utils.checkpoint import checkpoint
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
 
 DEFAULT_DIM_HEAD = 64
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
                 layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                          prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
                                              None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))