forked from mrq/DL-Art-School
some stuff
This commit is contained in:
parent
e23c322089
commit
15831b2576
|
@ -367,12 +367,14 @@ class ResBlock(nn.Module):
|
||||||
up=False,
|
up=False,
|
||||||
down=False,
|
down=False,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
|
checkpointing_enabled=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
self.use_conv = use_conv
|
self.use_conv = use_conv
|
||||||
|
self.checkpointing_enabled = checkpointing_enabled
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
|
@ -417,9 +419,12 @@ class ResBlock(nn.Module):
|
||||||
:param x: an [N x C x ...] Tensor of features.
|
:param x: an [N x C x ...] Tensor of features.
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
return checkpoint(
|
if self.checkpointing_enabled:
|
||||||
self._forward, x
|
return checkpoint(
|
||||||
)
|
self._forward, x
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
if self.updown:
|
if self.updown:
|
||||||
|
|
|
@ -31,9 +31,9 @@ class ConditioningEncoder(nn.Module):
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
zero_init_branch_output=True,
|
zero_init_branch_output=True,
|
||||||
ff_mult=2,
|
ff_mult=2,
|
||||||
|
do_checkpointing=do_checkpointing
|
||||||
)
|
)
|
||||||
self.dim = embedding_dim
|
self.dim = embedding_dim
|
||||||
self.do_checkpointing = do_checkpointing
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = self.init(x).permute(0,2,1)
|
h = self.init(x).permute(0,2,1)
|
||||||
|
|
44
codes/models/audio/music/encoders.py
Normal file
44
codes/models/audio/music/encoders.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers import GPT2Config, GPT2Model
|
||||||
|
|
||||||
|
from models.arch_util import AttentionBlock, ResBlock
|
||||||
|
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||||
|
from trainer.networks import register_model
|
||||||
|
from utils.util import opt_get, ceil_multiple, print_network
|
||||||
|
|
||||||
|
|
||||||
|
class ResEncoder16x(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
spec_dim,
|
||||||
|
hidden_dim,
|
||||||
|
embedding_dim,
|
||||||
|
checkpointing_enabled=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
attn = []
|
||||||
|
def edim(m):
|
||||||
|
dd = min(spec_dim + m * 128, hidden_dim)
|
||||||
|
return ceil_multiple(dd, 8)
|
||||||
|
self.downsampler = nn.Sequential(
|
||||||
|
ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
|
nn.GroupNorm(8, hidden_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.downsampler(x)
|
||||||
|
h = self.encoder(h)
|
||||||
|
return h
|
|
@ -14,6 +14,7 @@ class UpperEncoder(nn.Module):
|
||||||
spec_dim,
|
spec_dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
|
checkpointing_enabled=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
attn = []
|
attn = []
|
||||||
|
@ -21,18 +22,18 @@ class UpperEncoder(nn.Module):
|
||||||
dd = min(spec_dim + m * 128, hidden_dim)
|
dd = min(spec_dim + m * 128, hidden_dim)
|
||||||
return ceil_multiple(dd, 8)
|
return ceil_multiple(dd, 8)
|
||||||
self.downsampler = nn.Sequential(
|
self.downsampler = nn.Sequential(
|
||||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True),
|
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True),
|
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True),
|
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1),
|
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
|
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
|
||||||
self.encoder = nn.Sequential(
|
self.encoder = nn.Sequential(
|
||||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||||
nn.GroupNorm(8, hidden_dim),
|
nn.GroupNorm(8, hidden_dim),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
||||||
|
@ -45,6 +46,8 @@ class UpperEncoder(nn.Module):
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GptMusicLower(nn.Module):
|
class GptMusicLower(nn.Module):
|
||||||
def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
|
def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -456,15 +456,14 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
||||||
if hasattr(p, 'grad') and p.grad is not None:
|
if hasattr(p, 'grad') and p.grad is not None:
|
||||||
p.grad *= .2
|
p.grad *= .2
|
||||||
|
|
||||||
|
|
||||||
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
||||||
def __init__(self, freeze_encoder_until=None, **kwargs):
|
def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.internal_step = 0
|
self.internal_step = 0
|
||||||
self.freeze_encoder_until = freeze_encoder_until
|
self.freeze_encoder_until = freeze_encoder_until
|
||||||
self.diff = TransformerDiffusion(**kwargs)
|
self.diff = TransformerDiffusion(**kwargs)
|
||||||
self.encoder = UpperEncoder(256, 1024, 256)
|
from models.audio.music.encoders import ResEncoder16x
|
||||||
self.encoder = self.encoder.eval()
|
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||||
unused_parameters = []
|
unused_parameters = []
|
||||||
|
|
|
@ -98,11 +98,24 @@ class RandomAudioCropInjector(Injector):
|
||||||
self.max_crop_sz = opt['max_crop_size']
|
self.max_crop_sz = opt['max_crop_size']
|
||||||
self.lengths_key = opt['lengths_key']
|
self.lengths_key = opt['lengths_key']
|
||||||
self.crop_start_key = opt['crop_start_key']
|
self.crop_start_key = opt['crop_start_key']
|
||||||
|
self.rand_buffer_ptr=9999
|
||||||
|
self.rand_buffer_sz=5000
|
||||||
|
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
||||||
inp = state[self.input]
|
inp = state[self.input]
|
||||||
|
if torch.distributed.get_world_size() > 1:
|
||||||
|
# All processes should agree, otherwise all processes wait to process max_crop_sz (effectively). But agreeing too often
|
||||||
|
# is expensive, so agree on a "chunk" at a time.
|
||||||
|
if self.rand_buffer_ptr >= self.rand_buffer_sz:
|
||||||
|
self.rand_buffer = torch.randint(self.min_crop_sz, self.max_crop_sz, (self.rand_buffer_sz,), dtype=torch.long, device=inp.device)
|
||||||
|
torch.distributed.broadcast(self.rand_buffer, 0)
|
||||||
|
self.rand_buffer_ptr = 0
|
||||||
|
crop_sz = self.rand_buffer[self.rand_buffer_ptr]
|
||||||
|
self.rand_buffer_ptr += 1
|
||||||
|
else:
|
||||||
|
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
||||||
if self.lengths_key is not None:
|
if self.lengths_key is not None:
|
||||||
lens = state[self.lengths_key]
|
lens = state[self.lengths_key]
|
||||||
len = torch.min(lens)
|
len = torch.min(lens)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user