From 2814307eeea170cff163d1c36eab93cf8f6f2cc9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 1 Aug 2021 07:54:21 -0600
Subject: [PATCH] Alterations to support VQVAE on mel spectrograms

---
 codes/models/vqvae/vqvae.py               |  4 ++--
 codes/train.py                            |  2 +-
 codes/trainer/injectors/base_injectors.py | 21 +++++++++++++++++++++
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py
index 837af1fd..012864f3 100644
--- a/codes/models/vqvae/vqvae.py
+++ b/codes/models/vqvae/vqvae.py
@@ -271,6 +271,6 @@ def register_vqvae_audio(opt_net, opt):
 
 
 if __name__ == '__main__':
-    model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
-    res=model(torch.randn(1,224))
+    model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
+    res=model(torch.randn(1,80,224))
     print(res[0].shape)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index da87349b..35e6c88f 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 12fc126b..0a4ffebc 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -10,6 +10,27 @@ from utils.util import opt_get
 from utils.weight_scheduler import get_scheduler_for_opt
 
 
+class PadInjector(Injector):
+    """Pads the end of a tensor's last dimension up to a multiple of opt['multiple'].
+
+    NOTE(review): self.input / self.output are presumably populated by
+    Injector.__init__ from opt - confirm against the base class.
+    """
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.multiple = opt['multiple']
+
+    def forward(self, state):
+        """Returns {self.output: tensor}; already-aligned inputs pass through unpadded."""
+        t = state[self.input]
+        # -n % m is the distance from n up to the next multiple of m, and is 0
+        # when n is already aligned, so no padding is added in that case.
+        mod = -t.shape[-1] % self.multiple
+        if mod != 0:
+            t = torch.nn.functional.pad(t, (0, mod))
+        return {self.output: t}
+
+
 class SqueezeInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)