forked from mrq/DL-Art-School
Alterations to support VQVAE on mel spectrograms
This commit is contained in:
parent
965f6e6b52
commit
2814307eee
|
@ -271,6 +271,6 @@ def register_vqvae_audio(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
||||
res=model(torch.randn(1,224))
|
||||
model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
||||
res=model(torch.randn(1,80,224))
|
||||
print(res[0].shape)
|
|
@ -300,7 +300,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -10,6 +10,20 @@ from utils.util import opt_get
|
|||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
|
||||
|
||||
class PadInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.multiple = opt['multiple']
|
||||
|
||||
def forward(self, state):
|
||||
ldim = state[self.input].shape[-1]
|
||||
mod = self.multiple-(ldim % self.multiple)
|
||||
t = state[self.input]
|
||||
if mod != 0:
|
||||
t = torch.nn.functional.pad(t, (0, mod))
|
||||
return {self.output: t}
|
||||
|
||||
|
||||
class SqueezeInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
|
|
Loading…
Reference in New Issue
Block a user