From 66f99a159cf4f06719d1c7275ae5ce43ffddb981 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 3 Oct 2021 15:20:50 -0600
Subject: [PATCH] Rev2

---
 .../dvae_channel_attention.py                 | 84 +++++++++++++++----
 1 file changed, 70 insertions(+), 14 deletions(-)

diff --git a/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py b/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
index d61e36a6..e4d9266f 100644
--- a/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
+++ b/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import einsum
 
 from models.diffusion.unet_diffusion import AttentionBlock
+from models.stylegan.stylegan2_rosinality import EqualLinear
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -28,34 +29,89 @@ def eval_decorator(fn):
     return inner
 
 
+class ModulatedConv1d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        demodulate=True,
+        initial_weight_factor=1,
+    ):
+        super().__init__()
+
+        self.eps = 1e-8
+        self.kernel_size = kernel_size
+        self.in_channel = in_channel
+        self.out_channel = out_channel
+
+        fan_in = in_channel * kernel_size ** 2
+        self.scale = initial_weight_factor / math.sqrt(fan_in)
+        self.padding = kernel_size // 2
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channel, in_channel, kernel_size)
+        )
+
+        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+        self.demodulate = demodulate
+
+    def forward(self, input, style):
+        batch, in_channel, d = input.shape
+
+        style = self.modulation(style).view(batch, 1, in_channel, 1)
+        weight = self.scale * self.weight * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
+            weight = weight * demod.view(batch, self.out_channel, 1, 1)
+
+        weight = weight.view(
+            batch * self.out_channel, in_channel, self.kernel_size
+        )
+
+        input = input.view(1, batch * in_channel, d)
+        out = F.conv1d(input, weight, padding=self.padding, groups=batch)
+        _, _, d = out.shape
+        out = out.view(batch, self.out_channel, d)
+
+        return out
+
+
 class ChannelAttentionModule(nn.Module):
     def __init__(self, channels_in, channels_out, attention_dim, layers, num_heads=1):
         super().__init__()
+        self.channels_in = channels_in
         self.channels_out = channels_out
+        # This is the bypass. It performs the same computation, without attention. It is responsible for stabilizing
+        # training early on by being more optimizable.
+        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
+
         self.positional_embeddings = nn.Embedding(channels_out, attention_dim)
-        self.first_layer = nn.Conv1d(channels_in, attention_dim, kernel_size=1)
+        self.first_layer = ModulatedConv1d(1, attention_dim, kernel_size=1, style_dim=channels_in, initial_weight_factor=.1)
         self.layers = nn.Sequential(*[AttentionBlock(attention_dim, num_heads=num_heads) for _ in range(layers)])
         self.post_attn_layer = nn.Conv1d(attention_dim, 1, kernel_size=1)
-        # This is the bypass. It performs the same computation, without attention. Stabilizes the network early on.
-        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
-        self.mix = nn.Parameter(torch.zeros((1,)))
-
     def forward(self, inp):
-        # Collapse structural dimension of x
-        b, c, w = inp.shape
-        x = inp.permute(0,2,1).reshape(b*w, c).unsqueeze(-1).repeat(1,1,self.channels_out)
-        x = self.first_layer(x)
-        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=x.device)).permute(1,0).unsqueeze(0)
+        bypass = self.bypass(inp)
+        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=inp.device)).permute(1,0).unsqueeze(0)
+
+        b, c, w = bypass.shape
+        # Reshape bypass so channels become structure and structure becomes part of the batch.
+        x = bypass.permute(0,2,1).reshape(b*w, c).unsqueeze(1)
+        # Reshape the input as well so it can be fed into the stylizer.
+        style = inp.permute(0,2,1).reshape(b*w, self.channels_in)
+        x = self.first_layer(x, style)
         x = emb + x
         x = self.layers(x)
+        x = x - emb  # Subtract off emb to further stabilize early training, where the attention layers do nothing.
         out = self.post_attn_layer(x).squeeze(1)
         out = out.view(b,w,self.channels_out).permute(0,2,1)
-        bypass = self.bypass(inp)
-
-        return bypass * (1-self.mix) + out * self.mix
+        return bypass + out
 
 
 class ResBlock(nn.Module):
@@ -189,7 +245,7 @@ class DiscreteVAE(nn.Module):
         return images
 
     def get_debug_values(self, step, __):
-        dbg = {'decoder_attn_bypass': self.decoder[-1].mix.item()}
+        dbg = {}
         if self.record_codes:
             # Report annealing schedule
             dbg.update({'histogram_codes': self.codes})
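
Note on the ModulatedConv1d addition: the sketch below is not part of the patch. It assumes only PyTorch and uses illustrative names (modulated_conv1d, style_scale); in the patch itself the style projection is an EqualLinear layer and the weight is a learned nn.Parameter. It walks through the trick the new layer relies on: a per-sample style vector scales the convolution weight along the input channels, the weight is re-normalized ("demodulated") so activation magnitudes stay roughly constant, and a single grouped conv1d then applies a distinct kernel to every batch element.

# Standalone sketch of style-modulated 1-D convolution (illustrative, not the repo's class).
import math
import torch
import torch.nn.functional as F

def modulated_conv1d(x, weight, style_scale, demodulate=True, eps=1e-8):
    # x: (batch, in_channel, length)
    # weight: (1, out_channel, in_channel, kernel) -- shared base kernel
    # style_scale: (batch, in_channel) -- already projected from a style vector
    batch, in_channel, length = x.shape
    _, out_channel, _, kernel = weight.shape

    # Modulate: scale the base weight per sample and per input channel.
    w = weight * style_scale.view(batch, 1, in_channel, 1)
    if demodulate:
        # Demodulate: normalize each (sample, output channel) kernel to unit norm.
        demod = torch.rsqrt(w.pow(2).sum([2, 3]) + eps)
        w = w * demod.view(batch, out_channel, 1, 1)

    # Fold the batch into the channel dim so one grouped conv applies a different kernel per item.
    w = w.view(batch * out_channel, in_channel, kernel)
    x = x.view(1, batch * in_channel, length)
    out = F.conv1d(x, w, padding=kernel // 2, groups=batch)
    return out.view(batch, out_channel, -1)

if __name__ == "__main__":
    b, cin, cout, k, L = 4, 1, 64, 1, 256
    weight = torch.randn(1, cout, cin, k) / math.sqrt(cin * k)
    style = torch.randn(b, cin) + 1.0
    y = modulated_conv1d(torch.randn(b, cin, L), weight, style)
    print(y.shape)  # torch.Size([4, 64, 256])

In the patched ChannelAttentionModule this is used with in_channel=1 and kernel_size=1: the bypass output is reshaped to (b*w, 1, channels_out) and the original channels_in vector at each position becomes the style, so every spatial position receives its own modulated 1x1 kernel before the attention stack.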