Rev2

parent 09f373e3b1
commit 66f99a159c
@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import einsum

 from models.diffusion.unet_diffusion import AttentionBlock
+from models.stylegan.stylegan2_rosinality import EqualLinear
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import opt_get
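Context for the new import: EqualLinear is the equalized-learning-rate linear layer from the rosinality StyleGAN2 port, and it produces the per-sample modulation vector for the new ModulatedConv1d below. A rough, self-contained sketch of the behaviour that matters here (not the repository's implementation); with bias_init=1 the modulation starts out near 1, so the modulated conv initially acts like a plain scaled convolution:

```python
import math
import torch
from torch import nn
import torch.nn.functional as F

class EqualLinearSketch(nn.Module):
    """Illustrative stand-in for rosinality's EqualLinear: weights are drawn from
    N(0, 1) and rescaled at call time so the effective learning rate is equalized."""
    def __init__(self, in_dim, out_dim, bias_init=0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.full((out_dim,), float(bias_init)))
        self.scale = 1 / math.sqrt(in_dim)

    def forward(self, x):
        return F.linear(x, self.weight * self.scale, self.bias)

# With bias_init=1, a zero style vector still yields a modulation close to 1 (identity-like).
print(EqualLinearSketch(512, 1, bias_init=1)(torch.zeros(2, 512)))  # values equal to 1.0
```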
@@ -28,34 +29,89 @@ def eval_decorator(fn):
     return inner


+class ModulatedConv1d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        demodulate=True,
+        initial_weight_factor=1,
+    ):
+        super().__init__()
+
+        self.eps = 1e-8
+        self.kernel_size = kernel_size
+        self.in_channel = in_channel
+        self.out_channel = out_channel
+
+        fan_in = in_channel * kernel_size ** 2
+        self.scale = initial_weight_factor / math.sqrt(fan_in)
+        self.padding = kernel_size // 2
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channel, in_channel, kernel_size)
+        )
+
+        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+        self.demodulate = demodulate
+
+    def forward(self, input, style):
+        batch, in_channel, d = input.shape
+
+        style = self.modulation(style).view(batch, 1, in_channel, 1)
+        weight = self.scale * self.weight * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
+            weight = weight * demod.view(batch, self.out_channel, 1, 1)
+
+        weight = weight.view(
+            batch * self.out_channel, in_channel, self.kernel_size
+        )
+
+        input = input.view(1, batch * in_channel, d)
+        out = F.conv1d(input, weight, padding=self.padding, groups=batch)
+        _, _, d = out.shape
+        out = out.view(batch, self.out_channel, d)
+
+        return out
+
+
 class ChannelAttentionModule(nn.Module):
     def __init__(self, channels_in, channels_out, attention_dim, layers, num_heads=1):
         super().__init__()
+        self.channels_in = channels_in
         self.channels_out = channels_out
+
+        # This is the bypass. It performs the same computation, without attention. It is responsible for stabilizing
+        # training early on by being more optimizable.
+        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
+
         self.positional_embeddings = nn.Embedding(channels_out, attention_dim)
-        self.first_layer = nn.Conv1d(channels_in, attention_dim, kernel_size=1)
+        self.first_layer = ModulatedConv1d(1, attention_dim, kernel_size=1, style_dim=channels_in, initial_weight_factor=.1)
         self.layers = nn.Sequential(*[AttentionBlock(attention_dim, num_heads=num_heads) for _ in range(layers)])
         self.post_attn_layer = nn.Conv1d(attention_dim, 1, kernel_size=1)
-
-        # This is the bypass. It performs the same computation, without attention. Stabilizes the network early on.
-        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
-        self.mix = nn.Parameter(torch.zeros((1,)))

     def forward(self, inp):
-        # Collapse structural dimension of x
-        b, c, w = inp.shape
-        x = inp.permute(0,2,1).reshape(b*w, c).unsqueeze(-1).repeat(1,1,self.channels_out)
-        x = self.first_layer(x)
-        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=x.device)).permute(1,0).unsqueeze(0)
+        bypass = self.bypass(inp)
+        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=inp.device)).permute(1,0).unsqueeze(0)
+        b, c, w = bypass.shape
+        # Reshape bypass so channels become structure and structure becomes part of the batch.
+        x = bypass.permute(0,2,1).reshape(b*w, c).unsqueeze(1)
+        # Reshape the input as well so it can be fed into the stylizer.
+        style = inp.permute(0,2,1).reshape(b*w, self.channels_in)
+        x = self.first_layer(x, style)
         x = emb + x
         x = self.layers(x)
+        x = x - emb # Subtract of emb to further stabilize early training, where the attention layers do nothing.
         out = self.post_attn_layer(x).squeeze(1)
         out = out.view(b,w,self.channels_out).permute(0,2,1)
-
-        bypass = self.bypass(inp)
-        return bypass * (1-self.mix) + out * self.mix
+        return bypass + out


 class ResBlock(nn.Module):
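For reviewers unfamiliar with the StyleGAN2 trick the new ModulatedConv1d is built on: the shared base weight is scaled per sample by a style vector, demodulated back to roughly unit variance, and then applied with a grouped conv1d so each sample in the batch gets its own kernel. A minimal shape-checking sketch of that technique (names and sizes here are illustrative, not taken from the diff):

```python
import math
import torch
import torch.nn.functional as F

# Illustrative sizes only; in the diff, first_layer uses in_channel=1 and kernel_size=1.
batch, in_channel, out_channel, kernel_size, length = 4, 1, 64, 1, 256

weight = torch.randn(1, out_channel, in_channel, kernel_size)  # shared base kernel
style = torch.randn(batch, 1, in_channel, 1)                   # per-sample modulation (from EqualLinear)
scale = 1 / math.sqrt(in_channel * kernel_size ** 2)           # fan-in scaling, as in the diff

w = scale * weight * style                                     # (batch, out, in, k): one kernel per sample
demod = torch.rsqrt(w.pow(2).sum([2, 3]) + 1e-8)               # demodulation keeps output variance near 1
w = w * demod.view(batch, out_channel, 1, 1)

# Grouped conv with groups=batch applies each sample's private kernel to that sample only.
x = torch.randn(batch, in_channel, length)
out = F.conv1d(x.view(1, batch * in_channel, length),
               w.view(batch * out_channel, in_channel, kernel_size),
               padding=kernel_size // 2, groups=batch)
print(out.view(batch, out_channel, -1).shape)                  # torch.Size([4, 64, 256])
```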
@@ -189,7 +245,7 @@ class DiscreteVAE(nn.Module):
         return images

     def get_debug_values(self, step, __):
-        dbg = {'decoder_attn_bypass': self.decoder[-1].mix.item()}
+        dbg = {}
         if self.record_codes:
             # Report annealing schedule
             dbg.update({'histogram_codes': self.codes})
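A hedged smoke test of the rewired ChannelAttentionModule, assuming the patched file imports cleanly; the sizes are arbitrary and the point is only the output contract: the bypass conv output now doubles as the attention input (channels become the sequence dimension, the raw input becomes the style), and the result is a plain bypass + out residual since the learned mix gate is gone.

```python
import torch
# ChannelAttentionModule is assumed to be imported from the patched file; sizes are arbitrary.
attn = ChannelAttentionModule(channels_in=128, channels_out=512, attention_dim=64, layers=2)
inp = torch.randn(2, 128, 32)       # (batch, channels_in, width)
out = attn(inp)
assert out.shape == (2, 512, 32)    # same shape as the bypass conv output
```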