good ole ddp..

pull/2/head
James Betker 2022-07-18 17:13:45 +07:00
parent cf57c352c8
commit c959e530cb
1 changed files with 2 additions and 17 deletions

@ -1,40 +1,25 @@
import itertools
from random import randrange
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, \
FeedForward
from trainer.networks import register_model
from utils.util import checkpoint, print_network
from utils.util import checkpoint
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
return torch.cat(h, dim=-1)
class SubBlock(nn.Module):
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
super().__init__()