|
|
|
@ -1,40 +1,25 @@
|
|
|
|
|
import itertools
|
|
|
|
|
from random import randrange
|
|
|
|
|
from time import time
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
|
|
|
|
|
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
|
|
|
|
|
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
|
|
|
|
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
|
|
|
|
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
|
|
|
|
from models.diffusion.unet_diffusion import TimestepBlock
|
|
|
|
|
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, \
|
|
|
|
|
FeedForward
|
|
|
|
|
from trainer.networks import register_model
|
|
|
|
|
from utils.util import checkpoint, print_network
|
|
|
|
|
from utils.util import checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_latent(t):
|
|
|
|
|
return t.dtype == torch.float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_sequence(t):
|
|
|
|
|
return t.dtype == torch.long
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiGroupEmbedding(nn.Module):
|
|
|
|
|
def __init__(self, tokens, groups, dim):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
|
|
|
|
return torch.cat(h, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SubBlock(nn.Module):
|
|
|
|
|
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
|
|
|
|
|
super().__init__()
|
|
|
|
|