forked from mrq/DL-Art-School
Mods to support vqvae in audio mode (1d)
This commit is contained in:
parent
5584cfcc7a
commit
d81386c1be
|
@ -116,17 +116,24 @@ class TextMelCollate():
|
||||||
'input_lengths': input_lengths,
|
'input_lengths': input_lengths,
|
||||||
'padded_mel': mel_padded,
|
'padded_mel': mel_padded,
|
||||||
'padded_gate': gate_padded,
|
'padded_gate': gate_padded,
|
||||||
'output_lengths': output_lengths
|
'output_lengths': output_lengths,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
params = {
|
params = {
|
||||||
'mode': 'nv_tacotron',
|
'mode': 'nv_tacotron',
|
||||||
'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
||||||
|
|
||||||
}
|
}
|
||||||
from data import create_dataset
|
from data import create_dataset
|
||||||
ds = create_dataset(params)
|
ds = create_dataset(params)
|
||||||
j = ds[0]
|
i = 0
|
||||||
print(j)
|
m = []
|
||||||
|
for b in ds:
|
||||||
|
m.append(b)
|
||||||
|
i += 1
|
||||||
|
if i > 9999:
|
||||||
|
break
|
||||||
|
m=torch.stack(m)
|
||||||
|
print(m.mean(), m.std())
|
||||||
|
|
|
@ -83,14 +83,14 @@ class Quantize(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
class ResBlock(nn.Module):
|
||||||
def __init__(self, in_channel, channel):
|
def __init__(self, in_channel, channel, conv_module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(in_channel, channel, 3, padding=1),
|
conv_module(in_channel, channel, 3, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(channel, in_channel, 1),
|
conv_module(channel, in_channel, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
@ -101,27 +101,27 @@ class ResBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
|
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, conv_module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if stride == 4:
|
if stride == 4:
|
||||||
blocks = [
|
blocks = [
|
||||||
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
|
conv_module(channel // 2, channel, 4, stride=2, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(channel, channel, 3, padding=1),
|
conv_module(channel, channel, 3, padding=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
elif stride == 2:
|
elif stride == 2:
|
||||||
blocks = [
|
blocks = [
|
||||||
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(channel // 2, channel, 3, padding=1),
|
conv_module(channel // 2, channel, 3, padding=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in range(n_res_block):
|
for i in range(n_res_block):
|
||||||
blocks.append(ResBlock(channel, n_res_channel))
|
blocks.append(ResBlock(channel, n_res_channel, conv_module))
|
||||||
|
|
||||||
blocks.append(nn.ReLU(inplace=True))
|
blocks.append(nn.ReLU(inplace=True))
|
||||||
|
|
||||||
|
@ -133,23 +133,23 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
|
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, conv_module, conv_transpose_module
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
|
blocks = [conv_module(in_channel, channel, 3, padding=1)]
|
||||||
|
|
||||||
for i in range(n_res_block):
|
for i in range(n_res_block):
|
||||||
blocks.append(ResBlock(channel, n_res_channel))
|
blocks.append(ResBlock(channel, n_res_channel, conv_module))
|
||||||
|
|
||||||
blocks.append(nn.ReLU(inplace=True))
|
blocks.append(nn.ReLU(inplace=True))
|
||||||
|
|
||||||
if stride == 4:
|
if stride == 4:
|
||||||
blocks.extend(
|
blocks.extend(
|
||||||
[
|
[
|
||||||
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
|
conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.ConvTranspose2d(
|
conv_transpose_module(
|
||||||
channel // 2, out_channel, 4, stride=2, padding=1
|
channel // 2, out_channel, 4, stride=2, padding=1
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@ -157,7 +157,7 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
elif stride == 2:
|
elif stride == 2:
|
||||||
blocks.append(
|
blocks.append(
|
||||||
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
|
conv_transpose_module(channel, out_channel, 4, stride=2, padding=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.blocks = nn.Sequential(*blocks)
|
self.blocks = nn.Sequential(*blocks)
|
||||||
|
@ -175,20 +175,25 @@ class VQVAE(nn.Module):
|
||||||
n_res_channel=32,
|
n_res_channel=32,
|
||||||
codebook_dim=64,
|
codebook_dim=64,
|
||||||
codebook_size=512,
|
codebook_size=512,
|
||||||
|
conv_module=nn.Conv2d,
|
||||||
|
conv_transpose_module=nn.ConvTranspose2d,
|
||||||
decay=0.99,
|
decay=0.99,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
|
self.unsqueeze_channels = in_channel == -1
|
||||||
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
|
in_channel = abs(in_channel)
|
||||||
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
|
|
||||||
|
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
|
||||||
|
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
|
||||||
|
self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
|
||||||
self.quantize_t = Quantize(codebook_dim, codebook_size)
|
self.quantize_t = Quantize(codebook_dim, codebook_size)
|
||||||
self.dec_t = Decoder(
|
self.dec_t = Decoder(
|
||||||
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
|
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module
|
||||||
)
|
)
|
||||||
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
|
self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1)
|
||||||
self.quantize_b = Quantize(codebook_dim, codebook_size)
|
self.quantize_b = Quantize(codebook_dim, codebook_size)
|
||||||
self.upsample_t = nn.ConvTranspose2d(
|
self.upsample_t = conv_transpose_module(
|
||||||
codebook_dim, codebook_dim, 4, stride=2, padding=1
|
codebook_dim, codebook_dim, 4, stride=2, padding=1
|
||||||
)
|
)
|
||||||
self.dec = Decoder(
|
self.dec = Decoder(
|
||||||
|
@ -198,11 +203,17 @@ class VQVAE(nn.Module):
|
||||||
n_res_block,
|
n_res_block,
|
||||||
n_res_channel,
|
n_res_channel,
|
||||||
stride=4,
|
stride=4,
|
||||||
|
conv_module=conv_module,
|
||||||
|
conv_transpose_module=conv_transpose_module
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
if self.unsqueeze_channels:
|
||||||
|
input = input.unsqueeze(1)
|
||||||
quant_t, quant_b, diff, _, _ = self.encode(input)
|
quant_t, quant_b, diff, _, _ = self.encode(input)
|
||||||
dec = self.decode(quant_t, quant_b)
|
dec = self.decode(quant_t, quant_b)
|
||||||
|
if self.unsqueeze_channels:
|
||||||
|
dec = dec.squeeze(1)
|
||||||
|
|
||||||
return dec, diff
|
return dec, diff
|
||||||
|
|
||||||
|
@ -210,17 +221,17 @@ class VQVAE(nn.Module):
|
||||||
enc_b = checkpoint(self.enc_b, input)
|
enc_b = checkpoint(self.enc_b, input)
|
||||||
enc_t = checkpoint(self.enc_t, enc_b)
|
enc_t = checkpoint(self.enc_t, enc_b)
|
||||||
|
|
||||||
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
|
quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
|
||||||
quant_t, diff_t, id_t = self.quantize_t(quant_t)
|
quant_t, diff_t, id_t = self.quantize_t(quant_t)
|
||||||
quant_t = quant_t.permute(0, 3, 1, 2)
|
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||||
diff_t = diff_t.unsqueeze(0)
|
diff_t = diff_t.unsqueeze(0)
|
||||||
|
|
||||||
dec_t = checkpoint(self.dec_t, quant_t)
|
dec_t = checkpoint(self.dec_t, quant_t)
|
||||||
enc_b = torch.cat([dec_t, enc_b], 1)
|
enc_b = torch.cat([dec_t, enc_b], 1)
|
||||||
|
|
||||||
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
|
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
|
||||||
quant_b, diff_b, id_b = self.quantize_b(quant_b)
|
quant_b, diff_b, id_b = self.quantize_b(quant_b)
|
||||||
quant_b = quant_b.permute(0, 3, 1, 2)
|
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||||
diff_b = diff_b.unsqueeze(0)
|
diff_b = diff_b.unsqueeze(0)
|
||||||
|
|
||||||
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
|
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
|
||||||
|
@ -234,9 +245,9 @@ class VQVAE(nn.Module):
|
||||||
|
|
||||||
def decode_code(self, code_t, code_b):
|
def decode_code(self, code_t, code_b):
|
||||||
quant_t = self.quantize_t.embed_code(code_t)
|
quant_t = self.quantize_t.embed_code(code_t)
|
||||||
quant_t = quant_t.permute(0, 3, 1, 2)
|
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||||
quant_b = self.quantize_b.embed_code(code_b)
|
quant_b = self.quantize_b.embed_code(code_b)
|
||||||
quant_b = quant_b.permute(0, 3, 1, 2)
|
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||||
|
|
||||||
dec = self.decode(quant_t, quant_b)
|
dec = self.decode(quant_t, quant_b)
|
||||||
|
|
||||||
|
@ -247,6 +258,19 @@ class VQVAE(nn.Module):
|
||||||
def register_vqvae(opt_net, opt):
|
def register_vqvae(opt_net, opt):
|
||||||
kw = opt_get(opt_net, ['kwargs'], {})
|
kw = opt_get(opt_net, ['kwargs'], {})
|
||||||
vq = VQVAE(**kw)
|
vq = VQVAE(**kw)
|
||||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
|
||||||
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
|
|
||||||
return vq
|
return vq
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def register_vqvae_audio(opt_net, opt):
|
||||||
|
kw = opt_get(opt_net, ['kwargs'], {})
|
||||||
|
kw['conv_module'] = nn.Conv1d
|
||||||
|
kw['conv_transpose_module'] = nn.ConvTranspose1d
|
||||||
|
vq = VQVAE(**kw)
|
||||||
|
return vq
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
||||||
|
res=model(torch.randn(1,224))
|
||||||
|
print(res[0].shape)
|
|
@ -28,3 +28,5 @@ pytorch_fid==0.1.1
|
||||||
inflect==0.2.5
|
inflect==0.2.5
|
||||||
librosa==0.6.0
|
librosa==0.6.0
|
||||||
Unidecode==1.0.22
|
Unidecode==1.0.22
|
||||||
|
tgt == 1.4.4
|
||||||
|
pyworld == 0.2.10
|
|
@ -300,7 +300,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -419,3 +419,38 @@ class NoiseInjector(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
shape = (state[self.input].shape[0],) + self.shape
|
shape = (state[self.input].shape[0],) + self.shape
|
||||||
return {self.output: torch.randn(shape, device=state[self.input].device)}
|
return {self.output: torch.randn(shape, device=state[self.input].device)}
|
||||||
|
|
||||||
|
|
||||||
|
# Incorporates the specified dimension into the batch dimension.
|
||||||
|
class DecomposeDimensionInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.dim = opt['dim']
|
||||||
|
assert self.dim != 0 # Cannot decompose the batch dimension
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
inp = state[self.input]
|
||||||
|
dims = list(range(len(inp.shape))) # Looks like [0,1,2,3]
|
||||||
|
shape = list(inp.shape)
|
||||||
|
del dims[self.dim]
|
||||||
|
del shape[self.dim]
|
||||||
|
return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))}
|
||||||
|
|
||||||
|
|
||||||
|
# Performs normalization across fixed constants.
|
||||||
|
class NormalizeInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.shift = opt['shift']
|
||||||
|
self.scale = opt['scale']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
inp = state[self.input]
|
||||||
|
out = (inp - self.shift) / self.scale
|
||||||
|
return {self.output: out}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
|
||||||
|
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)
|
Loading…
Reference in New Issue
Block a user