Alterations to support VQVAE on mel spectrograms

This commit is contained in:
James Betker 2021-08-01 07:54:21 -06:00
parent 965f6e6b52
commit 2814307eee
3 changed files with 17 additions and 3 deletions

View File

@ -271,6 +271,6 @@ def register_vqvae_audio(opt_net, opt):
# Manual smoke test: build a VQVAE with 1-D conv / transposed-conv modules,
# push a random tensor through it and print the shape of the first output.
# NOTE(review): this excerpt looks like a diff hunk whose +/- markers were
# stripped. The first model/res pair (in_channel=-1, input of shape (1, 224))
# is presumably the removed version and the second pair (in_channel=80, input
# of shape (1, 80, 224)) the added one - confirm against the repository.
if __name__ == '__main__':
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
res=model(torch.randn(1,224))
model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
res=model(torch.randn(1,80,224))
print(res[0].shape)

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -10,6 +10,20 @@ from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt
class PadInjector(Injector):
    """Right-pad a tensor's last dimension up to a multiple of a fixed size.

    Options:
        multiple: the length granularity; after padding, the size of the last
            dimension is the smallest multiple of this value that is >= the
            original size.

    The tensor is read from ``state[self.input]`` and the result is returned
    under ``self.output`` (both attributes are presumably set up by the
    ``Injector`` base class from ``opt`` - confirm against its definition).
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.multiple = opt['multiple']

    def forward(self, state):
        """Return ``{self.output: tensor}`` with the last dimension padded.

        A tensor whose last dimension is already a multiple of
        ``self.multiple`` is returned unchanged.
        """
        t = state[self.input]
        ldim = t.shape[-1]
        # Elements missing to reach the next multiple. Python's modulo of a
        # negative number is non-negative, so this is 0 when `ldim` is already
        # aligned. The earlier `multiple - (ldim % multiple)` produced
        # `multiple` in that case, which made the `!= 0` guard below always
        # true and appended a full extra block to aligned inputs.
        pad = (-ldim) % self.multiple
        if pad != 0:
            # (0, pad): nothing on the left, `pad` elements on the right of
            # the last dimension.
            t = torch.nn.functional.pad(t, (0, pad))
        return {self.output: t}
class SqueezeInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)