Merge remote-tracking branch 'origin/master'

James Betker 2022-07-13 21:26:59 -06:00
commit def70cd444
6 changed files with 84 additions and 20 deletions

View File

@@ -367,12 +367,14 @@ class ResBlock(nn.Module):
         up=False,
         down=False,
         kernel_size=3,
+        checkpointing_enabled=True,
     ):
         super().__init__()
         self.channels = channels
         self.dropout = dropout
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
+        self.checkpointing_enabled = checkpointing_enabled
         padding = 1 if kernel_size == 3 else 2
         self.in_layers = nn.Sequential(
@@ -417,9 +419,12 @@ class ResBlock(nn.Module):
         :param x: an [N x C x ...] Tensor of features.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        return checkpoint(
-            self._forward, x
-        )
+        if self.checkpointing_enabled:
+            return checkpoint(
+                self._forward, x
+            )
+        else:
+            return self._forward(x)
 
     def _forward(self, x):
         if self.updown:
@@ -1017,4 +1022,4 @@ def gather_2d(input, index):
     result = result.squeeze()
     if b == 1:
         result = result.unsqueeze(0)
-    return result
\ No newline at end of file
+    return result
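
The change above makes activation checkpointing optional rather than always-on. A minimal sketch of the same toggle pattern, using torch.utils.checkpoint directly (the diff routes through the repo's own checkpoint helper, which may differ), with a hypothetical ToyBlock standing in for ResBlock:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class ToyBlock(nn.Module):
        def __init__(self, channels, checkpointing_enabled=True):
            super().__init__()
            self.checkpointing_enabled = checkpointing_enabled
            self.body = nn.Sequential(
                nn.Conv1d(channels, channels, 3, padding=1),
                nn.SiLU(),
            )

        def forward(self, x):
            if self.checkpointing_enabled:
                # Trade compute for memory: activations inside _forward are
                # recomputed during the backward pass instead of being stored.
                return checkpoint(self._forward, x)
            return self._forward(x)

        def _forward(self, x):
            return self.body(x)

Disabling the flag is useful at inference time or when profiling, since checkpointing only pays off when gradients are needed.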

View File

@@ -31,9 +31,9 @@ class ConditioningEncoder(nn.Module):
             rotary_pos_emb=True,
             zero_init_branch_output=True,
             ff_mult=2,
+            do_checkpointing=do_checkpointing
         )
         self.dim = embedding_dim
+        self.do_checkpointing = do_checkpointing
 
     def forward(self, x):
         h = self.init(x).permute(0,2,1)
@@ -122,4 +122,4 @@ def test_ar():
 
 if __name__ == '__main__':
-    test_ar()
\ No newline at end of file
+    test_ar()

View File (new: models/audio/music/encoders.py)

@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import GPT2Config, GPT2Model
+
+from models.arch_util import AttentionBlock, ResBlock
+from models.audio.tts.lucidrains_dvae import DiscreteVAE
+from trainer.networks import register_model
+from utils.util import opt_get, ceil_multiple, print_network
+
+
+class ResEncoder16x(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 hidden_dim,
+                 embedding_dim,
+                 checkpointing_enabled=True,
+                 ):
+        super().__init__()
+        attn = []
+        def edim(m):
+            dd = min(spec_dim + m * 128, hidden_dim)
+            return ceil_multiple(dd, 8)
+        self.downsampler = nn.Sequential(
+            ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
+        self.encoder = nn.Sequential(
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            nn.GroupNorm(8, hidden_dim),
+            nn.SiLU(),
+            nn.Conv1d(hidden_dim, embedding_dim, 1),
+            nn.Tanh(),
+        )
+
+    def forward(self, x):
+        h = self.downsampler(x)
+        h = self.encoder(h)
+        return h
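
Four of the six downsampler blocks have down=True, each halving the sequence length, which is where the 16x in the class name comes from; edim() widens the channel count from spec_dim toward hidden_dim, and the final Tanh bounds the latent to (-1, 1). A hypothetical usage sketch (shapes assume each down=True block halves the frame axis):

    import torch

    enc = ResEncoder16x(spec_dim=256, hidden_dim=1024, embedding_dim=256,
                        checkpointing_enabled=False)
    mel = torch.randn(2, 256, 400)   # [batch, spec_dim, frames]
    latent = enc(mel)
    print(latent.shape)              # expected: torch.Size([2, 256, 25]), i.e. 400 / 16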

View File

@@ -14,6 +14,7 @@ class UpperEncoder(nn.Module):
                  spec_dim,
                  hidden_dim,
                  embedding_dim,
+                 checkpointing_enabled=True,
                  ):
         super().__init__()
         attn = []
@@ -21,18 +22,18 @@
             dd = min(spec_dim + m * 128, hidden_dim)
             return ceil_multiple(dd, 8)
         self.downsampler = nn.Sequential(
-            ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True),
-            ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True),
-            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True),
-            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1),
-            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
+            ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
         self.encoder = nn.Sequential(
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             nn.GroupNorm(8, hidden_dim),
             nn.SiLU(),
             nn.Conv1d(hidden_dim, embedding_dim, 1),
@@ -45,6 +46,8 @@ class UpperEncoder(nn.Module):
         return h
 
 
 class GptMusicLower(nn.Module):
     def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
         super().__init__()
@@ -170,4 +173,4 @@ def test_lower():
 
 if __name__ == '__main__':
-    test_lower()
\ No newline at end of file
+    test_lower()
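
UpperEncoder builds its channel schedule with the same edim() helper as ResEncoder16x. A worked example of the schedule, with a local stand-in for ceil_multiple (the real one lives in utils.util and is assumed to round up to the next multiple):

    def ceil_multiple(x, base):
        return ((x + base - 1) // base) * base

    spec_dim, hidden_dim = 256, 1024

    def edim(m):
        return ceil_multiple(min(spec_dim + m * 128, hidden_dim), 8)

    print([edim(m) for m in (1, 2, 3, 4)])  # [384, 512, 640, 768]

Each downsampling stage widens channels by 128 (clamped at hidden_dim), rounded up to a multiple of 8, presumably to keep intermediate widths divisible by the group norms used inside the blocks.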

View File

@@ -456,15 +456,14 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
             if hasattr(p, 'grad') and p.grad is not None:
                 p.grad *= .2
 
 
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=None, **kwargs):
+    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
         self.diff = TransformerDiffusion(**kwargs)
-        self.encoder = UpperEncoder(256, 1024, 256)
-        self.encoder = self.encoder.eval()
+        from models.audio.music.encoders import ResEncoder16x
+        self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
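
The constructor stores internal_step and freeze_encoder_until, which suggests the encoder is kept frozen until a step threshold is reached. A sketch of that idea under those assumptions (not the repo's exact logic):

    def maybe_freeze_encoder(encoder, step, freeze_until):
        # Freeze encoder weights until `step` reaches the threshold; a None
        # threshold means the encoder always trains.
        trainable = freeze_until is None or step >= freeze_until
        for p in encoder.parameters():
            p.requires_grad = trainable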

View File

@@ -98,11 +98,24 @@ class RandomAudioCropInjector(Injector):
         self.max_crop_sz = opt['max_crop_size']
         self.lengths_key = opt['lengths_key']
         self.crop_start_key = opt['crop_start_key']
+        self.rand_buffer_ptr = 9999
+        self.rand_buffer_sz = 5000
 
     def forward(self, state):
-        crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
         inp = state[self.input]
+        if torch.distributed.get_world_size() > 1:
+            # All processes must agree on the crop size, otherwise the largest crop
+            # gates every rank and each step effectively costs max_crop_sz. But
+            # synchronizing every step is expensive, so agree on a "chunk" of random
+            # values at a time.
+            if self.rand_buffer_ptr >= self.rand_buffer_sz:
+                self.rand_buffer = torch.randint(self.min_crop_sz, self.max_crop_sz, (self.rand_buffer_sz,), dtype=torch.long, device=inp.device)
+                torch.distributed.broadcast(self.rand_buffer, 0)
+                self.rand_buffer_ptr = 0
+            crop_sz = self.rand_buffer[self.rand_buffer_ptr]
+            self.rand_buffer_ptr += 1
+        else:
+            crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
         if self.lengths_key is not None:
             lens = state[self.lengths_key]
             len = torch.min(lens)
@@ -445,4 +458,4 @@ class MusicCheaterArInjector(Injector):
             self.needs_move = False
         with torch.no_grad():
             latents = self.cheater_ar(codes, cond, return_latent=True)
-        return {self.output: latents}
\ No newline at end of file
+        return {self.output: latents}
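
The pattern added to RandomAudioCropInjector generalizes: draw a large buffer of random integers, broadcast it once from rank 0, then consume it locally so every rank agrees for the next buffer_sz draws at the cost of a single collective. A standalone sketch, assuming an initialized torch.distributed process group (names here are illustrative):

    import torch
    import torch.distributed as dist

    class SyncedRandInt:
        def __init__(self, lo, hi, buffer_sz=5000, device='cuda'):
            self.lo, self.hi = lo, hi
            self.buffer_sz = buffer_sz
            self.device = device
            self.ptr = buffer_sz  # start exhausted to force a fill on first draw

        def draw(self):
            if self.ptr >= self.buffer_sz:
                # Every rank generates a buffer, then rank 0's values overwrite
                # the rest; one broadcast is amortized over buffer_sz draws.
                self.buf = torch.randint(self.lo, self.hi, (self.buffer_sz,),
                                         dtype=torch.long, device=self.device)
                dist.broadcast(self.buf, 0)
                self.ptr = 0
            v = int(self.buf[self.ptr])
            self.ptr += 1
            return v

Initializing the pointer past the end (like the injector's 9999 against a 5000-entry buffer) guarantees the first call fills and broadcasts the buffer.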