some stuff
This commit is contained in:
parent
e23c322089
commit
15831b2576
|
@ -367,12 +367,14 @@ class ResBlock(nn.Module):
|
|||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
checkpointing_enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.checkpointing_enabled = checkpointing_enabled
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
|
@ -417,9 +419,12 @@ class ResBlock(nn.Module):
|
|||
:param x: an [N x C x ...] Tensor of features.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(
|
||||
self._forward, x
|
||||
)
|
||||
if self.checkpointing_enabled:
|
||||
return checkpoint(
|
||||
self._forward, x
|
||||
)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
def _forward(self, x):
|
||||
if self.updown:
|
||||
|
@ -1017,4 +1022,4 @@ def gather_2d(input, index):
|
|||
result = result.squeeze()
|
||||
if b == 1:
|
||||
result = result.unsqueeze(0)
|
||||
return result
|
||||
return result
|
||||
|
|
|
@ -31,9 +31,9 @@ class ConditioningEncoder(nn.Module):
|
|||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=2,
|
||||
do_checkpointing=do_checkpointing
|
||||
)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
def forward(self, x):
|
||||
h = self.init(x).permute(0,2,1)
|
||||
|
@ -122,4 +122,4 @@ def test_ar():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ar()
|
||||
test_ar()
|
||||
|
|
44
codes/models/audio/music/encoders.py
Normal file
44
codes/models/audio/music/encoders.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, ceil_multiple, print_network
|
||||
|
||||
|
||||
class ResEncoder16x(nn.Module):
    """Convolutional encoder assembled from 1-D ``ResBlock`` stacks.

    The module has two stages, applied in order by :meth:`forward`:

    * ``downsampler`` -- six ``ResBlock`` layers that widen the channel count
      from ``spec_dim`` up to ``hidden_dim``.  Four of them are constructed
      with ``down=True``.
      NOTE(review): presumably each ``down=True`` block halves the sequence
      length, which would give the 16x reduction the class name refers to --
      confirm against ``ResBlock``.
    * ``encoder`` -- three ``ResBlock`` layers at ``hidden_dim`` followed by
      ``GroupNorm(8) -> SiLU -> Conv1d(kernel_size=1) -> Tanh``, projecting to
      ``embedding_dim`` channels.  Because the last layer is ``Tanh``, outputs
      lie in the open interval (-1, 1).

    Submodule names and the ordering inside each ``nn.Sequential`` are part of
    the state_dict layout; keep them stable so saved weights keep loading.

    :param spec_dim: number of input channels.
    :param hidden_dim: channel width of the ``encoder`` stage and upper bound
        for the intermediate widths.  ``nn.GroupNorm(8, hidden_dim)`` requires
        this to be divisible by 8.
    :param embedding_dim: number of output channels.
    :param checkpointing_enabled: forwarded unchanged to every ``ResBlock``.
    """

    def __init__(self,
                 spec_dim,
                 hidden_dim,
                 embedding_dim,
                 checkpointing_enabled=True,
                 ):
        super().__init__()

        def edim(m):
            # Width for stage ``m``: 128 extra channels per stage on top of
            # spec_dim, capped at hidden_dim, then passed through
            # ceil_multiple(..., 8).
            # NOTE(review): ceil_multiple presumably rounds up to a multiple
            # of 8 -- confirm in utils.util.
            return ceil_multiple(min(spec_dim + m * 128, hidden_dim), 8)

        def block(in_channels, out_channels, down=False):
            # Every ResBlock in this encoder shares these settings; only the
            # channel widths and the ``down`` flag vary between layers.
            return ResBlock(in_channels, out_channels=out_channels, use_conv=True, dims=1,
                            down=down, checkpointing_enabled=checkpointing_enabled)

        self.downsampler = nn.Sequential(
            block(spec_dim, edim(2), down=True),
            block(edim(2), edim(3), down=True),
            block(edim(3), edim(3)),
            block(edim(3), edim(4), down=True),
            block(edim(4), edim(4)),
            block(edim(4), hidden_dim, down=True))
        self.encoder = nn.Sequential(
            block(hidden_dim, hidden_dim),
            block(hidden_dim, hidden_dim),
            block(hidden_dim, hidden_dim),
            nn.GroupNorm(8, hidden_dim),
            nn.SiLU(),
            nn.Conv1d(hidden_dim, embedding_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through ``downsampler`` and then ``encoder``.

        :param x: input tensor.  The 1-D convolutions imply a
            ``[N, spec_dim, L]`` layout -- TODO confirm against callers.
        :return: the output of the ``encoder`` stage, with ``embedding_dim``
            channels.
        """
        h = self.downsampler(x)
        h = self.encoder(h)
        return h
|
|
@ -14,6 +14,7 @@ class UpperEncoder(nn.Module):
|
|||
spec_dim,
|
||||
hidden_dim,
|
||||
embedding_dim,
|
||||
checkpointing_enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
|
@ -21,18 +22,18 @@ class UpperEncoder(nn.Module):
|
|||
dd = min(spec_dim + m * 128, hidden_dim)
|
||||
return ceil_multiple(dd, 8)
|
||||
self.downsampler = nn.Sequential(
|
||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1),
|
||||
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
|
||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
|
||||
self.encoder = nn.Sequential(
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
nn.GroupNorm(8, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
||||
|
@ -45,6 +46,8 @@ class UpperEncoder(nn.Module):
|
|||
return h
|
||||
|
||||
|
||||
|
||||
|
||||
class GptMusicLower(nn.Module):
|
||||
def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
|
||||
super().__init__()
|
||||
|
@ -170,4 +173,4 @@ def test_lower():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lower()
|
||||
test_lower()
|
||||
|
|
|
@ -456,15 +456,14 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
|||
if hasattr(p, 'grad') and p.grad is not None:
|
||||
p.grad *= .2
|
||||
|
||||
|
||||
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
||||
def __init__(self, freeze_encoder_until=None, **kwargs):
|
||||
def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
|
||||
super().__init__()
|
||||
self.internal_step = 0
|
||||
self.freeze_encoder_until = freeze_encoder_until
|
||||
self.diff = TransformerDiffusion(**kwargs)
|
||||
self.encoder = UpperEncoder(256, 1024, 256)
|
||||
self.encoder = self.encoder.eval()
|
||||
from models.audio.music.encoders import ResEncoder16x
|
||||
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||
unused_parameters = []
|
||||
|
|
|
@ -98,11 +98,24 @@ class RandomAudioCropInjector(Injector):
|
|||
self.max_crop_sz = opt['max_crop_size']
|
||||
self.lengths_key = opt['lengths_key']
|
||||
self.crop_start_key = opt['crop_start_key']
|
||||
self.rand_buffer_ptr=9999
|
||||
self.rand_buffer_sz=5000
|
||||
|
||||
|
||||
def forward(self, state):
|
||||
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
||||
inp = state[self.input]
|
||||
if torch.distributed.get_world_size() > 1:
|
||||
# All processes should agree, otherwise all processes wait to process max_crop_sz (effectively). But agreeing too often
|
||||
# is expensive, so agree on a "chunk" at a time.
|
||||
if self.rand_buffer_ptr >= self.rand_buffer_sz:
|
||||
self.rand_buffer = torch.randint(self.min_crop_sz, self.max_crop_sz, (self.rand_buffer_sz,), dtype=torch.long, device=inp.device)
|
||||
torch.distributed.broadcast(self.rand_buffer, 0)
|
||||
self.rand_buffer_ptr = 0
|
||||
crop_sz = self.rand_buffer[self.rand_buffer_ptr]
|
||||
self.rand_buffer_ptr += 1
|
||||
else:
|
||||
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
||||
if self.lengths_key is not None:
|
||||
lens = state[self.lengths_key]
|
||||
len = torch.min(lens)
|
||||
|
@ -445,4 +458,4 @@ class MusicCheaterArInjector(Injector):
|
|||
self.needs_move = False
|
||||
with torch.no_grad():
|
||||
latents = self.cheater_ar(codes, cond, return_latent=True)
|
||||
return {self.output: latents}
|
||||
return {self.output: latents}
|
||||
|
|
Loading…
Reference in New Issue
Block a user