This commit is contained in:
James Betker 2021-10-03 15:20:50 -06:00
parent 09f373e3b1
commit 66f99a159c

View File

@ -9,6 +9,7 @@ from einops import rearrange
from torch import einsum from torch import einsum
from models.diffusion.unet_diffusion import AttentionBlock from models.diffusion.unet_diffusion import AttentionBlock
from models.stylegan.stylegan2_rosinality import EqualLinear
from models.vqvae.vqvae import Quantize from models.vqvae.vqvae import Quantize
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
@ -28,34 +29,89 @@ def eval_decorator(fn):
return inner return inner
class ModulatedConv1d(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
demodulate=True,
initial_weight_factor=1,
):
super().__init__()
self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
fan_in = in_channel * kernel_size ** 2
self.scale = initial_weight_factor / math.sqrt(fan_in)
self.padding = kernel_size // 2
self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size)
)
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate
def forward(self, input, style):
batch, in_channel, d = input.shape
style = self.modulation(style).view(batch, 1, in_channel, 1)
weight = self.scale * self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
weight = weight * demod.view(batch, self.out_channel, 1, 1)
weight = weight.view(
batch * self.out_channel, in_channel, self.kernel_size
)
input = input.view(1, batch * in_channel, d)
out = F.conv1d(input, weight, padding=self.padding, groups=batch)
_, _, d = out.shape
out = out.view(batch, self.out_channel, d)
return out
class ChannelAttentionModule(nn.Module): class ChannelAttentionModule(nn.Module):
def __init__(self, channels_in, channels_out, attention_dim, layers, num_heads=1): def __init__(self, channels_in, channels_out, attention_dim, layers, num_heads=1):
super().__init__() super().__init__()
self.channels_in = channels_in
self.channels_out = channels_out self.channels_out = channels_out
# This is the bypass. It performs the same computation, without attention. It is responsible for stabilizing
# training early on by being more optimizable.
self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
self.positional_embeddings = nn.Embedding(channels_out, attention_dim) self.positional_embeddings = nn.Embedding(channels_out, attention_dim)
self.first_layer = nn.Conv1d(channels_in, attention_dim, kernel_size=1) self.first_layer = ModulatedConv1d(1, attention_dim, kernel_size=1, style_dim=channels_in, initial_weight_factor=.1)
self.layers = nn.Sequential(*[AttentionBlock(attention_dim, num_heads=num_heads) for _ in range(layers)]) self.layers = nn.Sequential(*[AttentionBlock(attention_dim, num_heads=num_heads) for _ in range(layers)])
self.post_attn_layer = nn.Conv1d(attention_dim, 1, kernel_size=1) self.post_attn_layer = nn.Conv1d(attention_dim, 1, kernel_size=1)
# This is the bypass. It performs the same computation, without attention. Stabilizes the network early on.
self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
self.mix = nn.Parameter(torch.zeros((1,)))
def forward(self, inp): def forward(self, inp):
# Collapse structural dimension of x bypass = self.bypass(inp)
b, c, w = inp.shape emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=inp.device)).permute(1,0).unsqueeze(0)
x = inp.permute(0,2,1).reshape(b*w, c).unsqueeze(-1).repeat(1,1,self.channels_out)
x = self.first_layer(x) b, c, w = bypass.shape
emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=x.device)).permute(1,0).unsqueeze(0) # Reshape bypass so channels become structure and structure becomes part of the batch.
x = bypass.permute(0,2,1).reshape(b*w, c).unsqueeze(1)
# Reshape the input as well so it can be fed into the stylizer.
style = inp.permute(0,2,1).reshape(b*w, self.channels_in)
x = self.first_layer(x, style)
x = emb + x x = emb + x
x = self.layers(x) x = self.layers(x)
x = x - emb # Subtract of emb to further stabilize early training, where the attention layers do nothing.
out = self.post_attn_layer(x).squeeze(1) out = self.post_attn_layer(x).squeeze(1)
out = out.view(b,w,self.channels_out).permute(0,2,1) out = out.view(b,w,self.channels_out).permute(0,2,1)
bypass = self.bypass(inp) return bypass + out
return bypass * (1-self.mix) + out * self.mix
class ResBlock(nn.Module): class ResBlock(nn.Module):
@ -189,7 +245,7 @@ class DiscreteVAE(nn.Module):
return images return images
def get_debug_values(self, step, __): def get_debug_values(self, step, __):
dbg = {'decoder_attn_bypass': self.decoder[-1].mix.item()} dbg = {}
if self.record_codes: if self.record_codes:
# Report annealing schedule # Report annealing schedule
dbg.update({'histogram_codes': self.codes}) dbg.update({'histogram_codes': self.codes})