good ole ddp..

This commit is contained in:
James Betker 2022-07-18 17:13:45 -06:00
parent cf57c352c8
commit c959e530cb

View File

@ -1,40 +1,25 @@
import itertools import itertools
from random import randrange from random import randrange
from time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, \
FeedForward
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint, print_network from utils.util import checkpoint
def is_latent(t): def is_latent(t):
return t.dtype == torch.float return t.dtype == torch.float
def is_sequence(t): def is_sequence(t):
return t.dtype == torch.long return t.dtype == torch.long
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
return torch.cat(h, dim=-1)
class SubBlock(nn.Module): class SubBlock(nn.Module):
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout): def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
super().__init__() super().__init__()