From 97d7cbbc34a69078d8bb2f0189c97e50b6fbaa85 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 23 Jul 2021 10:58:14 -0600
Subject: [PATCH] Additional work for audio xformer (which doesnt really do a great job)

---
 codes/models/vqvae/vqvae_audio_xformer.py | 149 ++++++++++++++++++++++
 codes/train.py                            |   2 +-
 codes/trainer/injectors/base_injectors.py |   7 +-
 3 files changed, 156 insertions(+), 2 deletions(-)
 create mode 100644 codes/models/vqvae/vqvae_audio_xformer.py

diff --git a/codes/models/vqvae/vqvae_audio_xformer.py b/codes/models/vqvae/vqvae_audio_xformer.py
new file mode 100644
index 00000000..8597bdd5
--- /dev/null
+++ b/codes/models/vqvae/vqvae_audio_xformer.py
@@ -0,0 +1,149 @@
+import math
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+import torch.distributed as distributed
+
+from models.vqvae.vqvae import Quantize
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        x = x + self.pe[:, :x.size(1)]
+        return self.dropout(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 layer_norm_eps=1e-5, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super(TransformerEncoderLayer, self).__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True,
+                                               **factory_kwargs)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+        self.norm1 = nn.BatchNorm1d(d_model)
+        self.norm2 = nn.BatchNorm1d(d_model)
+
+        self.activation = nn.ReLU()
+
+    def __setstate__(self, state):
+        if 'activation' not in state:
+            state['activation'] = F.relu
+        super(TransformerEncoderLayer, self).__setstate__(state)
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+        src = src + src2
+        src = self.norm1(src.permute(0,2,1)).permute(0,2,1)
+        src2 = self.linear2(self.activation(self.linear1(src)))
+        src = src + src2
+        src = self.norm2(src.permute(0,2,1)).permute(0,2,1)
+        return src
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channel, channel, output_breadth, num_layers=8, compression_factor=8):
+        super().__init__()
+
+        self.compression_factor = compression_factor
+        self.pre_conv_stack = nn.Sequential(nn.Conv1d(in_channel, channel//4, kernel_size=3, padding=1),
+                                            nn.ReLU(),
+                                            nn.Conv1d(channel//4, channel//2, kernel_size=3, stride=2, padding=1),
+                                            nn.ReLU(),
+                                            nn.Conv1d(channel//2, channel//2, kernel_size=3, padding=1),
+                                            nn.ReLU(),
+                                            nn.Conv1d(channel//2, channel, kernel_size=3, stride=2, padding=1))
+        self.norm1 = nn.BatchNorm1d(channel)
+        self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth//4)
+        self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)
+
+    def forward(self, input):
+        x = self.norm1(self.pre_conv_stack(input)).permute(0,2,1)
+        x = self.positional_embeddings(x)
+        x = self.encode(x)
+        return x[:,:input.shape[2]//self.compression_factor,:]
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self, in_channel, out_channel, channel, output_breadth, num_layers=6
+    ):
+        super().__init__()
+
+        self.initial_conv = nn.Conv1d(in_channel, channel, kernel_size=1)
+        self.expand = output_breadth
+        self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth)
+        self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)
+        self.final_conv_stack = nn.Sequential(nn.Conv1d(channel, channel, kernel_size=3, padding=1),
+                                              nn.ReLU(),
+                                              nn.Conv1d(channel, out_channel, kernel_size=3, padding=1))
+
+    def forward(self, input):
+        x = self.initial_conv(input.permute(0,2,1)).permute(0,2,1)
+        x = nn.functional.pad(x, (0,0,0, self.expand-input.shape[1]))
+        x = self.positional_embeddings(x)
+        x = self.encode(x).permute(0,2,1)
+        return self.final_conv_stack(x)
+
+
+class VQVAE(nn.Module):
+    def __init__(
+        self,
+        data_channels=1,
+        channel=256,
+        codebook_dim=256,
+        codebook_size=512,
+        breadth=80,
+    ):
+        super().__init__()
+
+        self.enc = Encoder(data_channels, channel, breadth)
+        self.quantize_dense = nn.Linear(channel, codebook_dim)
+        self.quantize = Quantize(codebook_dim, codebook_size)
+        self.dec = Decoder(codebook_dim, data_channels, channel, breadth)
+
+    def forward(self, input):
+        input = input.unsqueeze(1)
+        quant, diff, _ = self.encode(input)
+        dec = checkpoint(self.dec, quant)
+        dec = dec.squeeze(1)
+        return dec, diff
+
+    def encode(self, input):
+        enc = checkpoint(self.enc, input)
+        quant = self.quantize_dense(enc)
+        quant, diff, id = self.quantize(quant)
+        diff = diff.unsqueeze(0)
+        return quant, diff, id
+
+
+@register_model
+def register_vqvae_xform_audio(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    vq = VQVAE(**kw)
+    return vq
+
+
+if __name__ == '__main__':
+    model = VQVAE()
+    res=model(torch.randn(4,80))
+    print(res[0].shape)
\ No newline at end of file
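A minimal shape-tracing sketch for the new module, using the default hyperparameters declared above (data_channels=1, channel=256, breadth=80, compression_factor=8); it assumes the repository's codes/ directory is on PYTHONPATH so that models.vqvae.vqvae and utils.util resolve.

    import torch
    from models.vqvae.vqvae_audio_xformer import VQVAE

    model = VQVAE()
    model.eval()                        # skip BatchNorm/codebook EMA updates for this shape check
    mel = torch.randn(4, 80)            # (batch, breadth)
    quant, diff, ids = model.encode(mel.unsqueeze(1))
    print(quant.shape)                  # torch.Size([4, 10, 256]): breadth // compression_factor code vectors
    print(ids.shape)                    # torch.Size([4, 10]): discrete codebook indices
    recon, diff = model(mel)
    print(recon.shape)                  # torch.Size([4, 80]): the decoder pads back out to breadth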
diff --git a/codes/train.py b/codes/train.py
index 35e6c88f..35e57a2d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_xform_audio_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
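The option file named in the new default, train_vqvae_xform_audio_lj.yml, is not included in this patch; presumably it points the trainer at the network registered above as vqvae_xform_audio. Whatever dictionary that file supplies under the network's kwargs key is forwarded unchanged to VQVAE.__init__, roughly as in this sketch (the keys and values shown are illustrative, not taken from the YAML):

    from models.vqvae.vqvae_audio_xformer import VQVAE
    from utils.util import opt_get

    opt_net = {'kwargs': {'channel': 128, 'codebook_size': 1024, 'breadth': 80}}  # hypothetical options
    vq = VQVAE(**opt_get(opt_net, ['kwargs'], {}))  # mirrors what register_vqvae_xform_audio(opt_net, opt) does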
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 9500a278..db191866 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -426,6 +426,7 @@ class DecomposeDimensionInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.dim = opt['dim']
+        self.cutoff_dim = opt_get(opt, ['cutoff_dim'], -1)
         assert self.dim != 0  # Cannot decompose the batch dimension
 
     def forward(self, state):
@@ -440,7 +441,11 @@ class DecomposeDimensionInjector(Injector):
         rev_permute = list(range(len(inp.shape)))[1:]  # Looks like [1,2,3]
         rev_permute = rev_permute[:self.dim] + [0] + (rev_permute[self.dim:] if self.dim < len(rev_permute) else [])
-        return {self.output: inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:])),
+        out = inp.permute([self.dim] + dims).reshape((-1,) + tuple(shape[1:]))
+        if self.cutoff_dim > -1:
+            out = out[:self.cutoff_dim]
+
+        return {self.output: out,
                 f'{self.output}_reverse_shape': rev_shape,
                 f'{self.output}_reverse_permute': rev_permute}
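The behavioural change to DecomposeDimensionInjector is the optional truncation: when cutoff_dim is set to a non-negative value, the decomposed tensor is cut down to its first cutoff_dim entries along the leading axis before being returned; the default of -1 leaves it untouched. A minimal sketch of just that step, with illustrative shapes:

    import torch

    out = torch.randn(32, 80)    # stand-in for the decomposed tensor built above
    cutoff_dim = 8               # as read by opt_get(opt, ['cutoff_dim'], -1)
    if cutoff_dim > -1:
        out = out[:cutoff_dim]   # keep only the first cutoff_dim slices
    print(out.shape)             # torch.Size([8, 80])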