forked from mrq/tortoise-tts
Get rid of checkpointing
It isn't needed in inference.
This commit is contained in:
parent 29c1d9e561
commit 958c6d2f73
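For context, a minimal sketch of why the wrapper buys nothing at inference (toy layer and shapes, not tortoise code): torch.utils.checkpoint.checkpoint only saves memory by discarding activations and recomputing them during the backward pass, and under torch.no_grad() there is no backward pass, so the checkpointed call and the direct call return the same result while the former just adds bookkeeping overhead.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a checkpointed layer; sizes are made up.
layer = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16)).eval()
x = torch.randn(2, 16)

with torch.no_grad():
    direct = layer(x)                                     # plain forward call
    wrapped = checkpoint(layer, x, use_reentrant=False)   # same math, extra machinery
    # (use_reentrant=False silences the newer-PyTorch default warning; omit on very old versions)

print(torch.allclose(direct, wrapped))  # True: identical outputs
print(direct.requires_grad)             # False: no graph is built, so there is nothing to recompute later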
@@ -342,7 +342,7 @@ class CheckpointedLayer(nn.Module):
         for k, v in kwargs.items():
             assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
         partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+        return partial(x, *args)
 
 
 class CheckpointedXTransformerEncoder(nn.Module):
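An alternative to deleting the call outright would be to gate it on the module's training state, keeping the memory savings for anyone who fine-tunes. A sketch of that option only (the MaybeCheckpointedLayer name and the training gate are hypothetical, not what this commit does):

import functools
import torch
import torch.nn as nn
import torch.utils.checkpoint

class MaybeCheckpointedLayer(nn.Module):
    """Hypothetical variant: checkpoint while training, call directly at eval time."""
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        partial = functools.partial(self.wrap, **kwargs)
        if self.training and torch.is_grad_enabled():
            # Training path: trade compute for memory by recomputing activations in backward.
            return torch.utils.checkpoint.checkpoint(partial, x, *args, use_reentrant=False)
        # Inference (or no_grad) path: checkpointing has nothing to save, so skip it.
        return partial(x, *args)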
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
@@ -64,14 +63,6 @@ class ResBlock(nn.Module):
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
 
     def forward(self, x):
-        if self.do_checkpoint:
-            return checkpoint(
-                self._forward, x
-            )
-        else:
-            return self._forward(x)
-
-    def _forward(self, x):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -125,7 +116,7 @@ class AudioMiniEncoder(nn.Module):
         h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
-            h = checkpoint(blk, h)
+            h = blk(h)
         return h[:, :, 0]
 
 
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from torch.utils.checkpoint import checkpoint
 
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
@@ -44,7 +43,7 @@ class CollapsingTransformer(nn.Module):
     def forward(self, x, **transformer_kwargs):
         h = self.transformer(x, **transformer_kwargs)
         h = h.permute(0, 2, 1)
-        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
+        h = self.pre_combiner(h).permute(0, 2, 1)
         if self.training:
             mask = torch.rand_like(h.float()) > self.mask_percentage
         else:
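A quick way to convince yourself that a change like checkpoint(self.pre_combiner, h) -> self.pre_combiner(h) is behavior-preserving at inference is an eval-mode comparison on a toy submodule (the Conv1d stack and shapes below are stand-ins, not the real pre_combiner):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical stand-in for pre_combiner; channel and sequence sizes are made up.
pre_combiner = nn.Sequential(nn.Conv1d(32, 32, 1), nn.ReLU(), nn.Conv1d(32, 8, 1)).eval()
h = torch.randn(2, 32, 50)  # (batch, channels, sequence)

with torch.no_grad():
    old = checkpoint(pre_combiner, h, use_reentrant=False).permute(0, 2, 1)
    new = pre_combiner(h).permute(0, 2, 1)

assert torch.allclose(old, new)  # same outputs: only the wrapper overhead is gone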
@@ -1,16 +1,12 @@
-import functools
 import math
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
+from collections import namedtuple
 from functools import partial
 from inspect import isfunction
-from collections import namedtuple
 
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
-
-from torch.utils.checkpoint import checkpoint
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, einsum
 
 DEFAULT_DIM_HEAD = 64
 
@@ -969,16 +965,16 @@ class AttentionLayers(nn.Module):
                 layer_past = None
 
             if layer_type == 'a':
-                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                              prev_attn, layer_mem, layer_past)
+                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                         prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                                  None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                             None, prev_attn, None, layer_past)
                 else:
-                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
+                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
-                out = checkpoint(block, x)
+                out = block(x)
 
             if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
                 present_key_values.append((k.detach(), v.detach()))