Mods to support vqvae in audio mode (1d)

This commit is contained in:
James Betker 2021-07-20 08:36:46 -06:00
parent 5584cfcc7a
commit d81386c1be
5 changed files with 104 additions and 36 deletions

View File

@ -116,17 +116,24 @@ class TextMelCollate():
'input_lengths': input_lengths,
'padded_mel': mel_padded,
'padded_gate': gate_padded,
'output_lengths': output_lengths
'output_lengths': output_lengths,
}
if __name__ == '__main__':
params = {
'mode': 'nv_tacotron',
'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
}
from data import create_dataset
ds = create_dataset(params)
j = ds[0]
print(j)
i = 0
m = []
for b in ds:
m.append(b)
i += 1
if i > 9999:
break
m=torch.stack(m)
print(m.mean(), m.std())

View File

@ -83,14 +83,14 @@ class Quantize(nn.Module):
class ResBlock(nn.Module):
def __init__(self, in_channel, channel):
def __init__(self, in_channel, channel, conv_module):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1),
conv_module(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1),
conv_module(channel, in_channel, 1),
)
def forward(self, input):
@ -101,27 +101,27 @@ class ResBlock(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, conv_module):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
conv_module(channel // 2, channel, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
conv_module(channel, channel, 3, padding=1),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1),
conv_module(channel // 2, channel, 3, padding=1),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(ResBlock(channel, n_res_channel, conv_module))
blocks.append(nn.ReLU(inplace=True))
@ -133,23 +133,23 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, conv_module, conv_transpose_module
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
blocks = [conv_module(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(ResBlock(channel, n_res_channel, conv_module))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(
conv_transpose_module(
channel // 2, out_channel, 4, stride=2, padding=1
),
]
@ -157,7 +157,7 @@ class Decoder(nn.Module):
elif stride == 2:
blocks.append(
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
conv_transpose_module(channel, out_channel, 4, stride=2, padding=1)
)
self.blocks = nn.Sequential(*blocks)
@ -175,20 +175,25 @@ class VQVAE(nn.Module):
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
conv_module=nn.Conv2d,
conv_transpose_module=nn.ConvTranspose2d,
decay=0.99,
):
super().__init__()
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.unsqueeze_channels = in_channel == -1
in_channel = abs(in_channel)
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = nn.ConvTranspose2d(
self.upsample_t = conv_transpose_module(
codebook_dim, codebook_dim, 4, stride=2, padding=1
)
self.dec = Decoder(
@ -198,11 +203,17 @@ class VQVAE(nn.Module):
n_res_block,
n_res_channel,
stride=4,
conv_module=conv_module,
conv_transpose_module=conv_transpose_module
)
def forward(self, input):
if self.unsqueeze_channels:
input = input.unsqueeze(1)
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
if self.unsqueeze_channels:
dec = dec.squeeze(1)
return dec, diff
@ -210,17 +221,17 @@ class VQVAE(nn.Module):
enc_b = checkpoint(self.enc_b, input)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
@ -234,9 +245,9 @@ class VQVAE(nn.Module):
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
dec = self.decode(quant_t, quant_b)
@ -247,6 +258,19 @@ class VQVAE(nn.Module):
def register_vqvae(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
vq = VQVAE(**kw)
if distributed.is_initialized() and distributed.get_world_size() > 1:
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
return vq
@register_model
def register_vqvae_audio(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
kw['conv_module'] = nn.Conv1d
kw['conv_transpose_module'] = nn.ConvTranspose1d
vq = VQVAE(**kw)
return vq
if __name__ == '__main__':
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
res=model(torch.randn(1,224))
print(res[0].shape)

View File

@ -28,3 +28,5 @@ pytorch_fid==0.1.1
inflect==0.2.5
librosa==0.6.0
Unidecode==1.0.22
tgt == 1.4.4
pyworld == 0.2.10

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -419,3 +419,38 @@ class NoiseInjector(Injector):
def forward(self, state):
shape = (state[self.input].shape[0],) + self.shape
return {self.output: torch.randn(shape, device=state[self.input].device)}
# Incorporates the specified dimension into the batch dimension.
class DecomposeDimensionInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.dim = opt['dim']
assert self.dim != 0 # Cannot decompose the batch dimension
def forward(self, state):
inp = state[self.input]
dims = list(range(len(inp.shape))) # Looks like [0,1,2,3]
shape = list(inp.shape)
del dims[self.dim]
del shape[self.dim]
return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))}
# Performs normalization across fixed constants.
class NormalizeInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.shift = opt['shift']
self.scale = opt['scale']
def forward(self, state):
inp = state[self.input]
out = (inp - self.shift) / self.scale
return {self.output: out}
if __name__ == '__main__':
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)