commit b6b4f10e1b
parent 0d3b831cf9
@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchsummary
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
+from models.diffusion.unet_diffusion import TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint, print_network

@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchsummary
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
+from models.diffusion.unet_diffusion import TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint, print_network
@@ -220,6 +219,5 @@ if __name__ == '__main__':
     model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, num_layers=16, in_groups=8)
     torch.save(model, 'sample.pth')
     print_network(model)
-    #torchsummary.torchsummary.summary(model, clip, ts, aligned_sequence, cond, return_code_pred=True)
     o = model(clip, ts, aligned_sequence, cond)
 
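The changes above drop the torchsummary import together with the commented-out summary call, leaving print_network(model) as the remaining model-size report. For reference, the parameter total it prints can be reproduced with plain PyTorch; the count_parameters helper and the stand-in model below are an illustrative sketch, not repository code.

import torch.nn as nn

def count_parameters(module: nn.Module) -> int:
    # Total number of elements across all parameter tensors of the module.
    return sum(p.numel() for p in module.parameters())

if __name__ == '__main__':
    # Stand-in for TransformerDiffusion, which needs the full repository to build.
    demo = nn.Sequential(nn.Linear(1536, 3072), nn.GELU(), nn.Linear(3072, 1536))
    print(f'{count_parameters(demo):,d} parameters')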

@@ -1,7 +1,7 @@
 import os
 
 if __name__ == '__main__':
-    basepath = 'Y:/clips/books2'
+    basepath = 'Y:/clips/podcasts-0'
 
     english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
     if not os.path.exists(english_file):

@@ -15,6 +15,8 @@ import cv2
 import torch
 import torchaudio
 from audio2numpy import open_audio
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torchvision.utils import make_grid
 from shutil import get_terminal_size
 import scp
@@ -617,3 +619,17 @@ def load_wav_to_torch(full_path):
     else:
         raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
+
+
+def get_network_description(network):
+    """Get the string and total parameters of the network"""
+    if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+        network = network.module
+    return str(network), sum(map(lambda x: x.numel(), network.parameters()))
+
+
+def print_network(net, name='some network'):
+    s, n = get_network_description(net)
+    net_struc_str = '{}'.format(net.__class__.__name__)
+    print('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
+    print(s)
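A minimal usage sketch for the two helpers added to utils.util above; the toy module and the name 'toy conv stack' are illustrative, not part of the commit.

import torch.nn as nn
from utils.util import get_network_description, print_network

toy = nn.Sequential(nn.Conv1d(8, 16, 3, padding=1), nn.ReLU(), nn.Conv1d(16, 8, 3, padding=1))
desc, n_params = get_network_description(toy)  # repr string and total parameter count (792 here)
print_network(toy, name='toy conv stack')      # prints "Network toy conv stack structure: Sequential, with parameters: 792"

# get_network_description unwraps DataParallel / DistributedDataParallel wrappers,
# so a wrapped model reports the same structure and count as the bare module.
wrapped = nn.DataParallel(toy)
assert get_network_description(wrapped)[1] == n_params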