forked from mrq/DL-Art-School
Alterations to support VQVAE on mel spectrograms
This commit is contained in:
parent
965f6e6b52
commit
2814307eee
|
@ -271,6 +271,6 @@ def register_vqvae_audio(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
||||||
res=model(torch.randn(1,224))
|
res=model(torch.randn(1,80,224))
|
||||||
print(res[0].shape)
|
print(res[0].shape)
|
|
@ -300,7 +300,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -10,6 +10,20 @@ from utils.util import opt_get
|
||||||
from utils.weight_scheduler import get_scheduler_for_opt
|
from utils.weight_scheduler import get_scheduler_for_opt
|
||||||
|
|
||||||
|
|
||||||
|
class PadInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.multiple = opt['multiple']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
ldim = state[self.input].shape[-1]
|
||||||
|
mod = self.multiple-(ldim % self.multiple)
|
||||||
|
t = state[self.input]
|
||||||
|
if mod != 0:
|
||||||
|
t = torch.nn.functional.pad(t, (0, mod))
|
||||||
|
return {self.output: t}
|
||||||
|
|
||||||
|
|
||||||
class SqueezeInjector(Injector):
|
class SqueezeInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user