From d81386c1bee6b480597c232a7f64e19326cd9ac4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 20 Jul 2021 08:36:46 -0600 Subject: [PATCH] Mods to support vqvae in audio mode (1d) --- codes/data/audio/nv_tacotron_dataset.py | 15 ++-- codes/models/vqvae/vqvae.py | 84 +++++++++++++++-------- codes/requirements.txt | 4 +- codes/train.py | 2 +- codes/trainer/injectors/base_injectors.py | 35 ++++++++++ 5 files changed, 104 insertions(+), 36 deletions(-) diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 2545d78b..dbd010e7 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -116,17 +116,24 @@ class TextMelCollate(): 'input_lengths': input_lengths, 'padded_mel': mel_padded, 'padded_gate': gate_padded, - 'output_lengths': output_lengths + 'output_lengths': output_lengths, } if __name__ == '__main__': params = { 'mode': 'nv_tacotron', - 'path': 'E:\\4k6k\\datasets\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt', + 'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt', } from data import create_dataset ds = create_dataset(params) - j = ds[0] - print(j) \ No newline at end of file + i = 0 + m = [] + for b in ds: + m.append(b) + i += 1 + if i > 9999: + break + m=torch.stack(m) + print(m.mean(), m.std()) diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py index f2b98ac4..837af1fd 100644 --- a/codes/models/vqvae/vqvae.py +++ b/codes/models/vqvae/vqvae.py @@ -83,14 +83,14 @@ class Quantize(nn.Module): class ResBlock(nn.Module): - def __init__(self, in_channel, channel): + def __init__(self, in_channel, channel, conv_module): super().__init__() self.conv = nn.Sequential( nn.ReLU(inplace=True), - nn.Conv2d(in_channel, channel, 3, padding=1), + conv_module(in_channel, channel, 3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(channel, in_channel, 1), + conv_module(channel, in_channel, 1), ) def forward(self, input): @@ -101,27 +101,27 
@@ class ResBlock(nn.Module): class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): + def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, conv_module): super().__init__() if stride == 4: blocks = [ - nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), + conv_module(in_channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), + conv_module(channel // 2, channel, 4, stride=2, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(channel, channel, 3, padding=1), + conv_module(channel, channel, 3, padding=1), ] elif stride == 2: blocks = [ - nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), + conv_module(in_channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 3, padding=1), + conv_module(channel // 2, channel, 3, padding=1), ] for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) + blocks.append(ResBlock(channel, n_res_channel, conv_module)) blocks.append(nn.ReLU(inplace=True)) @@ -133,23 +133,23 @@ class Encoder(nn.Module): class Decoder(nn.Module): def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride + self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, conv_module, conv_transpose_module ): super().__init__() - blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] + blocks = [conv_module(in_channel, channel, 3, padding=1)] for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) + blocks.append(ResBlock(channel, n_res_channel, conv_module)) blocks.append(nn.ReLU(inplace=True)) if stride == 4: blocks.extend( [ - nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), + conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d( + conv_transpose_module( channel // 2, out_channel, 4, 
stride=2, padding=1 ), ] @@ -157,7 +157,7 @@ class Decoder(nn.Module): elif stride == 2: blocks.append( - nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) + conv_transpose_module(channel, out_channel, 4, stride=2, padding=1) ) self.blocks = nn.Sequential(*blocks) @@ -175,20 +175,25 @@ class VQVAE(nn.Module): n_res_channel=32, codebook_dim=64, codebook_size=512, + conv_module=nn.Conv2d, + conv_transpose_module=nn.ConvTranspose2d, decay=0.99, ): super().__init__() - self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) + self.unsqueeze_channels = in_channel == -1 + in_channel = abs(in_channel) + + self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module) + self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module) + self.quantize_conv_t = conv_module(channel, codebook_dim, 1) self.quantize_t = Quantize(codebook_dim, codebook_size) self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 + codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) + self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1) self.quantize_b = Quantize(codebook_dim, codebook_size) - self.upsample_t = nn.ConvTranspose2d( + self.upsample_t = conv_transpose_module( codebook_dim, codebook_dim, 4, stride=2, padding=1 ) self.dec = Decoder( @@ -198,11 +203,17 @@ class VQVAE(nn.Module): n_res_block, n_res_channel, stride=4, + conv_module=conv_module, + conv_transpose_module=conv_transpose_module ) def forward(self, input): + if self.unsqueeze_channels: + input = input.unsqueeze(1) quant_t, quant_b, diff, _, _ = 
self.encode(input) dec = self.decode(quant_t, quant_b) + if self.unsqueeze_channels: + dec = dec.squeeze(1) return dec, diff @@ -210,17 +221,17 @@ class VQVAE(nn.Module): enc_b = checkpoint(self.enc_b, input) enc_t = checkpoint(self.enc_t, enc_b) - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) + quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1)) quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) + quant_t = quant_t.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1)) diff_t = diff_t.unsqueeze(0) dec_t = checkpoint(self.dec_t, quant_t) enc_b = torch.cat([dec_t, enc_b], 1) - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) + quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1)) quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) + quant_b = quant_b.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1)) diff_b = diff_b.unsqueeze(0) return quant_t, quant_b, diff_t + diff_b, id_t, id_b @@ -234,9 +245,9 @@ class VQVAE(nn.Module): def decode_code(self, code_t, code_b): quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) + quant_t = quant_t.permute((0,3,1,2) if len(quant_t.shape) == 4 else (0,2,1)) quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) + quant_b = quant_b.permute((0,3,1,2) if len(quant_b.shape) == 4 else (0,2,1)) dec = self.decode(quant_t, quant_b) @@ -247,6 +258,19 @@ class VQVAE(nn.Module): def register_vqvae(opt_net, opt): kw = opt_get(opt_net, ['kwargs'], {}) vq = VQVAE(**kw) - if distributed.is_initialized() and distributed.get_world_size() > 1: - vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq) return vq + + +@register_model +def register_vqvae_audio(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + kw['conv_module'] = nn.Conv1d + kw['conv_transpose_module'] = nn.ConvTranspose1d + vq = 
VQVAE(**kw) + return vq + + +if __name__ == '__main__': + model = VQVAE(in_channel=-1, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d) + res=model(torch.randn(1,224)) + print(res[0].shape) \ No newline at end of file diff --git a/codes/requirements.txt b/codes/requirements.txt index 5ddbb501..ecf91ac4 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -27,4 +27,6 @@ pytorch_fid==0.1.1 # For audio generation stuff inflect==0.2.5 librosa==0.6.0 -Unidecode==1.0.22 \ No newline at end of file +Unidecode==1.0.22 +tgt == 1.4.4 +pyworld == 0.2.10 \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index a2202343..35e6c88f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -300,7 +300,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tacotron2_lj.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 221181ee..ce53f972 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -419,3 +419,38 @@ class NoiseInjector(Injector): def forward(self, state): shape = (state[self.input].shape[0],) + self.shape return {self.output: torch.randn(shape, device=state[self.input].device)} + + +# Incorporates the specified dimension into the batch dimension. 
+class DecomposeDimensionInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.dim = opt['dim'] + assert self.dim != 0 # Cannot decompose the batch dimension + + def forward(self, state): + inp = state[self.input] + dims = list(range(len(inp.shape))) # Looks like [0,1,2,3] + shape = list(inp.shape) + del dims[self.dim] + del shape[self.dim] + return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))} + + +# Performs normalization across fixed constants. +class NormalizeInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.shift = opt['shift'] + self.scale = opt['scale'] + + def forward(self, state): + inp = state[self.input] + out = (inp - self.shift) / self.scale + return {self.output: out} + + + +if __name__ == '__main__': + inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None) + print(inj({'x':torch.randn(10,3,64,64)})['y'].shape) \ No newline at end of file