Mods to support vqvae in audio mode (1d)

This commit is contained in:
James Betker 2021-07-20 08:36:46 -06:00
parent 5584cfcc7a
commit d81386c1be
5 changed files with 104 additions and 36 deletions

View File

@ -116,17 +116,24 @@ class TextMelCollate():
'input_lengths': input_lengths, 'input_lengths': input_lengths,
'padded_mel': mel_padded, 'padded_mel': mel_padded,
'padded_gate': gate_padded, 'padded_gate': gate_padded,
'output_lengths': output_lengths 'output_lengths': output_lengths,
} }
if __name__ == '__main__': if __name__ == '__main__':
params = { params = {
'mode': 'nv_tacotron', 'mode': 'nv_tacotron',
'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt', 'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
} }
from data import create_dataset from data import create_dataset
ds = create_dataset(params) ds = create_dataset(params)
j = ds[0] i = 0
print(j) m = []
for b in ds:
m.append(b)
i += 1
if i > 9999:
break
m=torch.stack(m)
print(m.mean(), m.std())

View File

@ -83,14 +83,14 @@ class Quantize(nn.Module):
class ResBlock(nn.Module): class ResBlock(nn.Module):
def __init__(self, in_channel, channel): def __init__(self, in_channel, channel, conv_module):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1), conv_module(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1), conv_module(channel, in_channel, 1),
) )
def forward(self, input): def forward(self, input):
@ -101,27 +101,27 @@ class ResBlock(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, conv_module):
super().__init__() super().__init__()
if stride == 4: if stride == 4:
blocks = [ blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), conv_module(channel // 2, channel, 4, stride=2, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1), conv_module(channel, channel, 3, padding=1),
] ]
elif stride == 2: elif stride == 2:
blocks = [ blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1), conv_module(channel // 2, channel, 3, padding=1),
] ]
for i in range(n_res_block): for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel)) blocks.append(ResBlock(channel, n_res_channel, conv_module))
blocks.append(nn.ReLU(inplace=True)) blocks.append(nn.ReLU(inplace=True))
@ -133,23 +133,23 @@ class Encoder(nn.Module):
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__( def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, conv_module, conv_transpose_module
): ):
super().__init__() super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] blocks = [conv_module(in_channel, channel, 3, padding=1)]
for i in range(n_res_block): for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel)) blocks.append(ResBlock(channel, n_res_channel, conv_module))
blocks.append(nn.ReLU(inplace=True)) blocks.append(nn.ReLU(inplace=True))
if stride == 4: if stride == 4:
blocks.extend( blocks.extend(
[ [
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.ConvTranspose2d( conv_transpose_module(
channel // 2, out_channel, 4, stride=2, padding=1 channel // 2, out_channel, 4, stride=2, padding=1
), ),
] ]
@ -157,7 +157,7 @@ class Decoder(nn.Module):
elif stride == 2: elif stride == 2:
blocks.append( blocks.append(
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) conv_transpose_module(channel, out_channel, 4, stride=2, padding=1)
) )
self.blocks = nn.Sequential(*blocks) self.blocks = nn.Sequential(*blocks)
@ -175,20 +175,25 @@ class VQVAE(nn.Module):
n_res_channel=32, n_res_channel=32,
codebook_dim=64, codebook_dim=64,
codebook_size=512, codebook_size=512,
conv_module=nn.Conv2d,
conv_transpose_module=nn.ConvTranspose2d,
decay=0.99, decay=0.99,
): ):
super().__init__() super().__init__()
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) self.unsqueeze_channels = in_channel == -1
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) in_channel = abs(in_channel)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size) self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder( self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module
) )
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size) self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = nn.ConvTranspose2d( self.upsample_t = conv_transpose_module(
codebook_dim, codebook_dim, 4, stride=2, padding=1 codebook_dim, codebook_dim, 4, stride=2, padding=1
) )
self.dec = Decoder( self.dec = Decoder(
@ -198,11 +203,17 @@ class VQVAE(nn.Module):
n_res_block, n_res_block,
n_res_channel, n_res_channel,
stride=4, stride=4,
conv_module=conv_module,
conv_transpose_module=conv_transpose_module
) )
def forward(self, input): def forward(self, input):
if self.unsqueeze_channels:
input = input.unsqueeze(1)
quant_t, quant_b, diff, _, _ = self.encode(input) quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b) dec = self.decode(quant_t, quant_b)
if self.unsqueeze_channels:
dec = dec.squeeze(1)
return dec, diff return dec, diff
@ -210,17 +221,17 @@ class VQVAE(nn.Module):
enc_b = checkpoint(self.enc_b, input) enc_b = checkpoint(self.enc_b, input)
enc_t = checkpoint(self.enc_t, enc_b) enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
quant_t, diff_t, id_t = self.quantize_t(quant_t) quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2) quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
diff_t = diff_t.unsqueeze(0) diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t) dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1) enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
quant_b, diff_b, id_b = self.quantize_b(quant_b) quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2) quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
diff_b = diff_b.unsqueeze(0) diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b return quant_t, quant_b, diff_t + diff_b, id_t, id_b
@ -234,9 +245,9 @@ class VQVAE(nn.Module):
def decode_code(self, code_t, code_b): def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t) quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2) quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
quant_b = self.quantize_b.embed_code(code_b) quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2) quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
dec = self.decode(quant_t, quant_b) dec = self.decode(quant_t, quant_b)
@ -247,6 +258,19 @@ class VQVAE(nn.Module):
def register_vqvae(opt_net, opt): def register_vqvae(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {}) kw = opt_get(opt_net, ['kwargs'], {})
vq = VQVAE(**kw) vq = VQVAE(**kw)
if distributed.is_initialized() and distributed.get_world_size() > 1:
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
return vq return vq
@register_model
def register_vqvae_audio(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
kw['conv_module'] = nn.Conv1d
kw['conv_transpose_module'] = nn.ConvTranspose1d
vq = VQVAE(**kw)
return vq
if __name__ == '__main__':
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
res=model(torch.randn(1,224))
print(res[0].shape)

View File

@ -28,3 +28,5 @@ pytorch_fid==0.1.1
inflect==0.2.5 inflect==0.2.5
librosa==0.6.0 librosa==0.6.0
Unidecode==1.0.22 Unidecode==1.0.22
tgt == 1.4.4
pyworld == 0.2.10

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -419,3 +419,38 @@ class NoiseInjector(Injector):
def forward(self, state): def forward(self, state):
shape = (state[self.input].shape[0],) + self.shape shape = (state[self.input].shape[0],) + self.shape
return {self.output: torch.randn(shape, device=state[self.input].device)} return {self.output: torch.randn(shape, device=state[self.input].device)}
# Incorporates the specified dimension into the batch dimension.
class DecomposeDimensionInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.dim = opt['dim']
assert self.dim != 0 # Cannot decompose the batch dimension
def forward(self, state):
inp = state[self.input]
dims = list(range(len(inp.shape))) # Looks like [0,1,2,3]
shape = list(inp.shape)
del dims[self.dim]
del shape[self.dim]
return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))}
# Performs normalization across fixed constants.
class NormalizeInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.shift = opt['shift']
self.scale = opt['scale']
def forward(self, state):
inp = state[self.input]
out = (inp - self.shift) / self.scale
return {self.output: out}
if __name__ == '__main__':
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)