forked from mrq/DL-Art-School
good ole ddp..
This commit is contained in:
parent cf57c352c8
commit c959e530cb
@@ -1,40 +1,25 @@
 import itertools
 from random import randrange
-from time import time

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
-from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
-from models.audio.music.music_quantizer2 import MusicQuantizer2
-from models.audio.tts.lucidrains_dvae import DiscreteVAE
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
-from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, \
-    FeedForward
 from trainer.networks import register_model
-from utils.util import checkpoint, print_network
+from utils.util import checkpoint


 def is_latent(t):
     return t.dtype == torch.float


 def is_sequence(t):
     return t.dtype == torch.long


-class MultiGroupEmbedding(nn.Module):
-    def __init__(self, tokens, groups, dim):
-        super().__init__()
-        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
-
-    def forward(self, x):
-        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
-        return torch.cat(h, dim=-1)
-
 class SubBlock(nn.Module):
     def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
         super().__init__()
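For reference, the deleted MultiGroupEmbedding gives each token group its own nn.Embedding table of width dim // groups and concatenates the per-group lookups along the feature axis. Below is a minimal, standalone sketch of how it would be exercised; the class body is copied from the removed block, but the argument values and tensor shapes are illustrative assumptions, not taken from this commit.

import torch
import torch.nn as nn

class MultiGroupEmbedding(nn.Module):
    # Copied from the removed block: one embedding table per token group,
    # each dim // groups wide, concatenated on the last axis.
    def __init__(self, tokens, groups, dim):
        super().__init__()
        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])

    def forward(self, x):
        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
        return torch.cat(h, dim=-1)

# Illustrative usage (assumed shapes): 2 sequences of 16 steps,
# each step carrying 4 group indices drawn from a 1024-entry codebook.
emb = MultiGroupEmbedding(tokens=1024, groups=4, dim=512)
codes = torch.randint(0, 1024, (2, 16, 4))  # (batch, sequence, groups)
out = emb(codes)                            # (2, 16, 512): four 128-wide lookups concatenated
print(out.shape)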