This commit is contained in:
James Betker 2022-05-28 10:59:03 -06:00
parent 0d3b831cf9
commit b6b4f10e1b
4 changed files with 19 additions and 6 deletions

View File

@ -1,10 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint, print_network

View File

@ -1,10 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint, print_network
@ -220,6 +219,5 @@ if __name__ == '__main__':
model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, num_layers=16, in_groups=8)
torch.save(model, 'sample.pth')
print_network(model)
#torchsummary.torchsummary.summary(model, clip, ts, aligned_sequence, cond, return_code_pred=True)
o = model(clip, ts, aligned_sequence, cond)

View File

@ -1,7 +1,7 @@
import os
if __name__ == '__main__':
basepath = 'Y:/clips/books2'
basepath = 'Y:/clips/podcasts-0'
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
if not os.path.exists(english_file):

View File

@ -15,6 +15,8 @@ import cv2
import torch
import torchaudio
from audio2numpy import open_audio
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torchvision.utils import make_grid
from shutil import get_terminal_size
import scp
@ -617,3 +619,17 @@ def load_wav_to_torch(full_path):
else:
raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def get_network_description(network):
"""Get the string and total parameters of the network"""
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module
return str(network), sum(map(lambda x: x.numel(), network.parameters()))
def print_network(net, name='some network'):
s, n = get_network_description(net)
net_struc_str = '{}'.format(net.__class__.__name__)
print('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
print(s)