diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 66db3b21..89f09d6c 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -1,40 +1,25 @@
 import itertools
 from random import randrange
-from time import time
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
-from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
-from models.audio.music.music_quantizer2 import MusicQuantizer2
-from models.audio.tts.lucidrains_dvae import DiscreteVAE
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
-from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, \
-    FeedForward
 from trainer.networks import register_model
-from utils.util import checkpoint, print_network
+from utils.util import checkpoint
 
 
 def is_latent(t):
     return t.dtype == torch.float
+
 
 def is_sequence(t):
     return t.dtype == torch.long
 
 
-class MultiGroupEmbedding(nn.Module):
-    def __init__(self, tokens, groups, dim):
-        super().__init__()
-        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
-
-    def forward(self, x):
-        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
-        return torch.cat(h, dim=-1)
-
-
 class SubBlock(nn.Module):
     def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
         super().__init__()