forked from mrq/DL-Art-School
Mods to support vqvae in audio mode (1d)
This commit is contained in:
parent
5584cfcc7a
commit
d81386c1be
|
@ -116,17 +116,24 @@ class TextMelCollate():
|
|||
'input_lengths': input_lengths,
|
||||
'padded_mel': mel_padded,
|
||||
'padded_gate': gate_padded,
|
||||
'output_lengths': output_lengths
|
||||
'output_lengths': output_lengths,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
params = {
|
||||
'mode': 'nv_tacotron',
|
||||
'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
||||
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
||||
|
||||
}
|
||||
from data import create_dataset
|
||||
ds = create_dataset(params)
|
||||
j = ds[0]
|
||||
print(j)
|
||||
i = 0
|
||||
m = []
|
||||
for b in ds:
|
||||
m.append(b)
|
||||
i += 1
|
||||
if i > 9999:
|
||||
break
|
||||
m=torch.stack(m)
|
||||
print(m.mean(), m.std())
|
||||
|
|
|
@ -83,14 +83,14 @@ class Quantize(nn.Module):
|
|||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, channel):
|
||||
def __init__(self, in_channel, channel, conv_module):
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channel, channel, 3, padding=1),
|
||||
conv_module(in_channel, channel, 3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channel, in_channel, 1),
|
||||
conv_module(channel, in_channel, 1),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
|
@ -101,27 +101,27 @@ class ResBlock(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
|
||||
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, conv_module):
|
||||
super().__init__()
|
||||
|
||||
if stride == 4:
|
||||
blocks = [
|
||||
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
|
||||
conv_module(channel // 2, channel, 4, stride=2, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channel, channel, 3, padding=1),
|
||||
conv_module(channel, channel, 3, padding=1),
|
||||
]
|
||||
|
||||
elif stride == 2:
|
||||
blocks = [
|
||||
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||
conv_module(in_channel, channel // 2, 4, stride=2, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channel // 2, channel, 3, padding=1),
|
||||
conv_module(channel // 2, channel, 3, padding=1),
|
||||
]
|
||||
|
||||
for i in range(n_res_block):
|
||||
blocks.append(ResBlock(channel, n_res_channel))
|
||||
blocks.append(ResBlock(channel, n_res_channel, conv_module))
|
||||
|
||||
blocks.append(nn.ReLU(inplace=True))
|
||||
|
||||
|
@ -133,23 +133,23 @@ class Encoder(nn.Module):
|
|||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
|
||||
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, conv_module, conv_transpose_module
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
|
||||
blocks = [conv_module(in_channel, channel, 3, padding=1)]
|
||||
|
||||
for i in range(n_res_block):
|
||||
blocks.append(ResBlock(channel, n_res_channel))
|
||||
blocks.append(ResBlock(channel, n_res_channel, conv_module))
|
||||
|
||||
blocks.append(nn.ReLU(inplace=True))
|
||||
|
||||
if stride == 4:
|
||||
blocks.extend(
|
||||
[
|
||||
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
|
||||
conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(
|
||||
conv_transpose_module(
|
||||
channel // 2, out_channel, 4, stride=2, padding=1
|
||||
),
|
||||
]
|
||||
|
@ -157,7 +157,7 @@ class Decoder(nn.Module):
|
|||
|
||||
elif stride == 2:
|
||||
blocks.append(
|
||||
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
|
||||
conv_transpose_module(channel, out_channel, 4, stride=2, padding=1)
|
||||
)
|
||||
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
@ -175,20 +175,25 @@ class VQVAE(nn.Module):
|
|||
n_res_channel=32,
|
||||
codebook_dim=64,
|
||||
codebook_size=512,
|
||||
conv_module=nn.Conv2d,
|
||||
conv_transpose_module=nn.ConvTranspose2d,
|
||||
decay=0.99,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
|
||||
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
|
||||
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
|
||||
self.unsqueeze_channels = in_channel == -1
|
||||
in_channel = abs(in_channel)
|
||||
|
||||
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
|
||||
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
|
||||
self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
|
||||
self.quantize_t = Quantize(codebook_dim, codebook_size)
|
||||
self.dec_t = Decoder(
|
||||
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
|
||||
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module
|
||||
)
|
||||
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
|
||||
self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1)
|
||||
self.quantize_b = Quantize(codebook_dim, codebook_size)
|
||||
self.upsample_t = nn.ConvTranspose2d(
|
||||
self.upsample_t = conv_transpose_module(
|
||||
codebook_dim, codebook_dim, 4, stride=2, padding=1
|
||||
)
|
||||
self.dec = Decoder(
|
||||
|
@ -198,11 +203,17 @@ class VQVAE(nn.Module):
|
|||
n_res_block,
|
||||
n_res_channel,
|
||||
stride=4,
|
||||
conv_module=conv_module,
|
||||
conv_transpose_module=conv_transpose_module
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
if self.unsqueeze_channels:
|
||||
input = input.unsqueeze(1)
|
||||
quant_t, quant_b, diff, _, _ = self.encode(input)
|
||||
dec = self.decode(quant_t, quant_b)
|
||||
if self.unsqueeze_channels:
|
||||
dec = dec.squeeze(1)
|
||||
|
||||
return dec, diff
|
||||
|
||||
|
@ -210,17 +221,17 @@ class VQVAE(nn.Module):
|
|||
enc_b = checkpoint(self.enc_b, input)
|
||||
enc_t = checkpoint(self.enc_t, enc_b)
|
||||
|
||||
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
|
||||
quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
|
||||
quant_t, diff_t, id_t = self.quantize_t(quant_t)
|
||||
quant_t = quant_t.permute(0, 3, 1, 2)
|
||||
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||
diff_t = diff_t.unsqueeze(0)
|
||||
|
||||
dec_t = checkpoint(self.dec_t, quant_t)
|
||||
enc_b = torch.cat([dec_t, enc_b], 1)
|
||||
|
||||
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
|
||||
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
|
||||
quant_b, diff_b, id_b = self.quantize_b(quant_b)
|
||||
quant_b = quant_b.permute(0, 3, 1, 2)
|
||||
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||
diff_b = diff_b.unsqueeze(0)
|
||||
|
||||
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
|
||||
|
@ -234,9 +245,9 @@ class VQVAE(nn.Module):
|
|||
|
||||
def decode_code(self, code_t, code_b):
|
||||
quant_t = self.quantize_t.embed_code(code_t)
|
||||
quant_t = quant_t.permute(0, 3, 1, 2)
|
||||
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||
quant_b = self.quantize_b.embed_code(code_b)
|
||||
quant_b = quant_b.permute(0, 3, 1, 2)
|
||||
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
|
||||
|
||||
dec = self.decode(quant_t, quant_b)
|
||||
|
||||
|
@ -247,6 +258,19 @@ class VQVAE(nn.Module):
|
|||
def register_vqvae(opt_net, opt):
|
||||
kw = opt_get(opt_net, ['kwargs'], {})
|
||||
vq = VQVAE(**kw)
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
|
||||
return vq
|
||||
|
||||
|
||||
@register_model
|
||||
def register_vqvae_audio(opt_net, opt):
|
||||
kw = opt_get(opt_net, ['kwargs'], {})
|
||||
kw['conv_module'] = nn.Conv1d
|
||||
kw['conv_transpose_module'] = nn.ConvTranspose1d
|
||||
vq = VQVAE(**kw)
|
||||
return vq
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
|
||||
res=model(torch.randn(1,224))
|
||||
print(res[0].shape)
|
|
@ -28,3 +28,5 @@ pytorch_fid==0.1.1
|
|||
inflect==0.2.5
|
||||
librosa==0.6.0
|
||||
Unidecode==1.0.22
|
||||
tgt == 1.4.4
|
||||
pyworld == 0.2.10
|
|
@ -300,7 +300,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -419,3 +419,38 @@ class NoiseInjector(Injector):
|
|||
def forward(self, state):
|
||||
shape = (state[self.input].shape[0],) + self.shape
|
||||
return {self.output: torch.randn(shape, device=state[self.input].device)}
|
||||
|
||||
|
||||
# Incorporates the specified dimension into the batch dimension.
|
||||
class DecomposeDimensionInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.dim = opt['dim']
|
||||
assert self.dim != 0 # Cannot decompose the batch dimension
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
dims = list(range(len(inp.shape))) # Looks like [0,1,2,3]
|
||||
shape = list(inp.shape)
|
||||
del dims[self.dim]
|
||||
del shape[self.dim]
|
||||
return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))}
|
||||
|
||||
|
||||
# Performs normalization across fixed constants.
|
||||
class NormalizeInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.shift = opt['shift']
|
||||
self.scale = opt['scale']
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
out = (inp - self.shift) / self.scale
|
||||
return {self.output: out}
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
|
||||
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)
|
Loading…
Reference in New Issue
Block a user