diff --git a/codes/models/audio/music/transformer_diffusion3.py b/codes/models/audio/music/transformer_diffusion3.py
index a2f8c371..43641303 100644
--- a/codes/models/audio/music/transformer_diffusion3.py
+++ b/codes/models/audio/music/transformer_diffusion3.py
@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchsummary
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
+from models.diffusion.unet_diffusion import TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint, print_network
diff --git a/codes/models/audio/music/transformer_diffusion4.py b/codes/models/audio/music/transformer_diffusion4.py
index 8fec716e..014661a4 100644
--- a/codes/models/audio/music/transformer_diffusion4.py
+++ b/codes/models/audio/music/transformer_diffusion4.py
@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchsummary
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
+from models.diffusion.unet_diffusion import TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint, print_network
@@ -220,6 +219,5 @@ if __name__ == '__main__':
     model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, num_layers=16, in_groups=8)
     torch.save(model, 'sample.pth')
     print_network(model)
-    #torchsummary.torchsummary.summary(model, clip, ts, aligned_sequence, cond, return_code_pred=True)
     o = model(clip, ts, aligned_sequence, cond)
 
diff --git a/codes/scripts/audio/preparation/combine_phonetic_and_text.py b/codes/scripts/audio/preparation/combine_phonetic_and_text.py
index cc0d344d..587e4f78 100644
--- a/codes/scripts/audio/preparation/combine_phonetic_and_text.py
+++ b/codes/scripts/audio/preparation/combine_phonetic_and_text.py
@@ -1,7 +1,7 @@
 import os
 
 if __name__ == '__main__':
-    basepath = 'Y:/clips/books2'
+    basepath = 'Y:/clips/podcasts-0'
 
     english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
     if not os.path.exists(english_file):
diff --git a/codes/utils/util.py b/codes/utils/util.py
index a5b83ead..b3e52397 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -15,6 +15,8 @@ import cv2
 import torch
 import torchaudio
 from audio2numpy import open_audio
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torchvision.utils import make_grid
 from shutil import get_terminal_size
 import scp
@@ -617,3 +619,17 @@ def load_wav_to_torch(full_path):
     else:
         raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
+
+
+def get_network_description(network):
+    """Get the string and total parameters of the network"""
+    if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+        network = network.module
+    return str(network), sum(map(lambda x: x.numel(), network.parameters()))
+
+
+def print_network(net, name='some network'):
+    s, n = get_network_description(net)
+    net_struc_str = '{}'.format(net.__class__.__name__)
+    print('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
+    print(s)