Rev2
parent 09f373e3b1
commit 66f99a159c
@@ -9,6 +9,7 @@ from einops import rearrange
from torch import einsum

from models.diffusion.unet_diffusion import AttentionBlock
from models.stylegan.stylegan2_rosinality import EqualLinear
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get
@@ -28,34 +29,89 @@ def eval_decorator(fn):
    return inner


class ModulatedConv1d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        initial_weight_factor=1,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel

        fan_in = in_channel * kernel_size ** 2
        self.scale = initial_weight_factor / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def forward(self, input, style):
        batch, in_channel, d = input.shape

        # Modulate: scale the shared kernel per-sample with the projected style vector.
        style = self.modulation(style).view(batch, 1, in_channel, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Demodulate: renormalize each output filter to unit norm over (in_channel, kernel).
            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size
        )

        # Fold the batch into the group dimension so every sample is convolved with its own kernels.
        input = input.view(1, batch * in_channel, d)
        out = F.conv1d(input, weight, padding=self.padding, groups=batch)
        _, _, d = out.shape
        out = out.view(batch, self.out_channel, d)

        return out


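For readers unfamiliar with StyleGAN2-style modulated convolutions, the sketch below (an editor's addition, not part of the commit) walks through the same modulate / demodulate / grouped-conv sequence that ModulatedConv1d.forward performs, using plain torch ops so it runs outside this repo. The sizes, tensor names, and the random stand-in for the style projection are assumptions; the real module additionally applies an equalized-lr scale and derives the style from an EqualLinear projection.

import torch
import torch.nn.functional as F

batch, in_channel, out_channel, kernel_size, d = 4, 8, 16, 1, 32
x = torch.randn(batch, in_channel, d)                        # (N, C_in, length)
style = torch.rand(batch, in_channel) + 0.5                  # per-sample, per-channel scales (stand-in)
weight = torch.randn(1, out_channel, in_channel, kernel_size)

# Modulate: scale the shared kernel per sample and per input channel.
w = weight * style.view(batch, 1, in_channel, 1)
# Demodulate: renormalize every output filter to unit norm over (in_channel, kernel).
w = w * torch.rsqrt(w.pow(2).sum([2, 3]) + 1e-8).view(batch, out_channel, 1, 1)

# Fold the batch into the group dimension so each sample sees its own kernels.
w = w.view(batch * out_channel, in_channel, kernel_size)
out = F.conv1d(x.reshape(1, batch * in_channel, d), w, padding=kernel_size // 2, groups=batch)
out = out.view(batch, out_channel, -1)
print(out.shape)  # torch.Size([4, 16, 32])
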

class ChannelAttentionModule(nn.Module):
    def __init__(self, channels_in, channels_out, attention_dim, layers, num_heads=1):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out

        # This is the bypass. It performs the same computation, without attention. It is responsible for stabilizing
        # training early on by being more optimizable.
        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)

        self.positional_embeddings = nn.Embedding(channels_out, attention_dim)
        self.first_layer = nn.Conv1d(channels_in, attention_dim, kernel_size=1)
        self.first_layer = ModulatedConv1d(1, attention_dim, kernel_size=1, style_dim=channels_in, initial_weight_factor=.1)
        self.layers = nn.Sequential(*[AttentionBlock(attention_dim, num_heads=num_heads) for _ in range(layers)])
        self.post_attn_layer = nn.Conv1d(attention_dim, 1, kernel_size=1)

        # This is the bypass. It performs the same computation, without attention. Stabilizes the network early on.
        self.bypass = nn.Conv1d(channels_in, channels_out, kernel_size=1)
        self.mix = nn.Parameter(torch.zeros((1,)))

    def forward(self, inp):
        # Collapse structural dimension of x
        b, c, w = inp.shape
        x = inp.permute(0,2,1).reshape(b*w, c).unsqueeze(-1).repeat(1,1,self.channels_out)
        x = self.first_layer(x)
        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=x.device)).permute(1,0).unsqueeze(0)
        bypass = self.bypass(inp)
        emb = self.positional_embeddings(torch.arange(0, self.channels_out, device=inp.device)).permute(1,0).unsqueeze(0)

        b, c, w = bypass.shape
        # Reshape bypass so channels become structure and structure becomes part of the batch.
        x = bypass.permute(0,2,1).reshape(b*w, c).unsqueeze(1)
        # Reshape the input as well so it can be fed into the stylizer.
        style = inp.permute(0,2,1).reshape(b*w, self.channels_in)
        x = self.first_layer(x, style)
        x = emb + x
        x = self.layers(x)
        x = x - emb  # Subtract off emb to further stabilize early training, where the attention layers do nothing.
        out = self.post_attn_layer(x).squeeze(1)
        out = out.view(b,w,self.channels_out).permute(0,2,1)

        bypass = self.bypass(inp)

        return bypass * (1-self.mix) + out * self.mix
        return bypass + out
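
To make the reshaping above easier to follow, here is a self-contained sketch (editor's addition, not part of the commit) of the collapse-and-restore trick the forward pass relies on: the channel axis becomes the "structural" axis the attention stack operates over, while every (batch, position) pair is folded into the batch dimension. The sizes and the random stand-in for self.bypass(inp) are assumptions.

import torch

b, c_in, c_out, w = 2, 8, 16, 10
inp = torch.randn(b, c_in, w)
bypass = torch.randn(b, c_out, w)        # stands in for self.bypass(inp)

# Channels become "structure"; (batch, position) pairs become the batch axis.
x = bypass.permute(0, 2, 1).reshape(b * w, c_out).unsqueeze(1)    # (b*w, 1, c_out)
style = inp.permute(0, 2, 1).reshape(b * w, c_in)                 # (b*w, c_in), fed to the stylizer

# ... the ModulatedConv1d + AttentionBlock stack would operate over the c_out axis here ...

# Undo the collapse to recover the usual (batch, channels, width) layout.
out = x.squeeze(1).reshape(b, w, c_out).permute(0, 2, 1)          # (b, c_out, w)
print(x.shape, style.shape, out.shape)
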
class ResBlock(nn.Module):

@@ -189,7 +245,7 @@ class DiscreteVAE(nn.Module):
        return images

    def get_debug_values(self, step, __):
        dbg = {'decoder_attn_bypass': self.decoder[-1].mix.item()}
        dbg = {}
        if self.record_codes:
            # Report annealing schedule
            dbg.update({'histogram_codes': self.codes})
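As a hedged illustration of how a debug dict like the one built in get_debug_values might be consumed, the sketch below (editor's addition, not part of the commit) logs scalar entries and histogram_* entries to TensorBoard. The writer, log directory, and sample values are assumptions; the actual trainer hook in this repo may differ.

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('/tmp/dbg_demo')   # hypothetical log dir
step = 100
dbg = {'decoder_attn_bypass': 0.37,
       'histogram_codes': torch.randint(0, 8192, (1024,))}   # sample values for illustration
for name, value in dbg.items():
    if name.startswith('histogram'):
        writer.add_histogram(name, value, step)
    else:
        writer.add_scalar(name, value, step)
writer.close()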