...
This commit is contained in:
parent
0d3b831cf9
commit
b6b4f10e1b
|
@ -1,10 +1,9 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchsummary
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchsummary
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
|
@ -220,6 +219,5 @@ if __name__ == '__main__':
|
|||
model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, num_layers=16, in_groups=8)
|
||||
torch.save(model, 'sample.pth')
|
||||
print_network(model)
|
||||
#torchsummary.torchsummary.summary(model, clip, ts, aligned_sequence, cond, return_code_pred=True)
|
||||
o = model(clip, ts, aligned_sequence, cond)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
basepath = 'Y:/clips/books2'
|
||||
basepath = 'Y:/clips/podcasts-0'
|
||||
|
||||
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
|
||||
if not os.path.exists(english_file):
|
||||
|
|
|
@ -15,6 +15,8 @@ import cv2
|
|||
import torch
|
||||
import torchaudio
|
||||
from audio2numpy import open_audio
|
||||
from torch import nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torchvision.utils import make_grid
|
||||
from shutil import get_terminal_size
|
||||
import scp
|
||||
|
@ -617,3 +619,17 @@ def load_wav_to_torch(full_path):
|
|||
else:
|
||||
raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
|
||||
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
|
||||
|
||||
|
||||
def get_network_description(network):
    """Return the string form and total parameter count of a network.

    Args:
        network: The network to describe. If it is an ``nn.DataParallel`` or
            ``DistributedDataParallel`` wrapper, the wrapped ``.module`` is
            described instead of the wrapper.

    Returns:
        A ``(description, param_count)`` tuple: ``str(network)`` and the sum
        of ``numel()`` over everything yielded by ``network.parameters()``.
    """
    # Describe the underlying model rather than the parallel wrapper.
    if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
        network = network.module
    param_count = sum(param.numel() for param in network.parameters())
    return str(network), param_count
|
||||
|
||||
|
||||
def print_network(net, name='some network'):
    """Print a one-line summary of a network, then its full structure.

    Args:
        net: The network to report on; it is handed to
            ``get_network_description`` to obtain the structure string and
            the parameter count.
        name: Label inserted into the summary line.
    """
    description, param_count = get_network_description(net)
    # The class name comes from ``net`` exactly as passed in, so a parallel
    # wrapper is reported under the wrapper's own class name.
    class_name = net.__class__.__name__
    print(f'Network {name} structure: {class_name}, with parameters: {param_count:,d}')
    print(description)
|
||||
|
|
Loading…
Reference in New Issue
Block a user