Clean stuff up, move more things into arch_util
This commit is contained in:
parent a6f0f854b9
commit f2a31702b5

@@ -8,6 +8,38 @@ import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps=1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)
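
# Illustrative example (hypothetical tensors, not part of this commit): these two
# helpers typically pair up in EMA codebook updates - smoothing keeps the counts
# of rarely-used codes nonzero so a later normalization never divides by zero.
#
#   cluster_size = torch.zeros(8)
#   batch_counts = torch.tensor([5., 3., 0., 0., 1., 0., 0., 1.])
#   ema_inplace(cluster_size, batch_counts, decay=0.99)
#   smoothed = laplace_smoothing(cluster_size, 8) * cluster_size.sum()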


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kaiming_init(module,
                 a=0,
                 mode='fan_out',

@@ -24,9 +56,11 @@ def kaiming_init(module,
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def pixel_norm(x, epsilon=1e-8):
    return x * torch.rsqrt(torch.mean(torch.pow(x, 2), dim=1, keepdim=True) + epsilon)
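
# Illustrative example (hypothetical shapes, not part of this commit): pixel_norm
# normalizes each spatial position across the channel dimension (dim=1), as in
# ProGAN/StyleGAN:
#
#   feat = torch.randn(2, 512, 4, 4)   # (batch, channels, h, w)
#   normed = pixel_norm(feat)          # per-pixel channel vectors get ~unit RMS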


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]

@@ -75,20 +109,12 @@ def default_init_weights(module, scale=1):
        elif isinstance(m, nn.Linear):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale

"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)
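
# A hedged sketch of the fallback the comment above describes (not part of this
# commit): prefer the built-in implementation when the installed torch has it.
if hasattr(nn, 'SiLU'):
    SiLU = nn.SiLU  # PyTorch >= 1.7 ships a native (and faster) SiLU.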


class GroupNorm32(nn.GroupNorm):

@@ -1,387 +0,0 @@
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.respace import SpacedDiffusion, space_timesteps
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
    Downsample, Upsample
import torch
import torch.nn as nn

from models.gpt_voice.lucidrains_dvae import eval_decorator
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.gumbel_quantizer import GumbelQuantizer
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import get_mask_from_lengths
import models.gpt_voice.mini_encoder as menc


class DiscreteEncoder(nn.Module):
    def __init__(self,
                 in_channels,
                 model_channels,
                 out_channels,
                 dropout,
                 scale):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_nd(1, in_channels, model_channels, 3, padding=1),
            menc.ResBlock(model_channels, dropout, dims=1),
            Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale),
            menc.ResBlock(model_channels*2, dropout, dims=1),
            Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale),
            menc.ResBlock(model_channels*4, dropout, dims=1),
            AttentionBlock(model_channels*4, num_heads=4),
            menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
        )

    def forward(self, spectrogram):
        return self.blocks(spectrogram)


class DiscreteDecoder(nn.Module):
    def __init__(self, in_channels, level_channels, scale):
        super().__init__()
        # Just raw upsampling; returns a list with the output of each layer.
        self.init = conv_nd(1, in_channels, level_channels[0], kernel_size=3, padding=1)
        layers = []
        for i, lvl in enumerate(level_channels[:-1]):
            layers.append(nn.Sequential(normalization(lvl),
                                        nn.SiLU(),
                                        Upsample(lvl, use_conv=True, dims=1, out_channels=level_channels[i+1], factor=scale)))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        y = self.init(x)
        outs = [y]
        for layer in self.layers:
            y = layer(y)
            outs.append(y)
        return outs


class DiffusionDVAE(nn.Module):
    def __init__(
        self,
        model_channels,
        num_res_blocks,
        in_channels=1,
        out_channels=2,  # mean and variance
        spectrogram_channels=80,
        spectrogram_conditioning_levels=[3,4,5],  # Levels at which spectrogram conditioning is applied to the waveform.
        dropout=0,
        channel_mult=(1, 2, 4, 8, 16, 32, 64),
        attention_resolutions=(16,32,64),
        conv_resample=True,
        dims=1,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        use_new_attention_order=False,
        kernel_size=5,
        quantize_dim=1024,
        num_discrete_codes=8192,
        scale_steps=4,
        conditioning_inputs_provided=True,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.spectrogram_channels = spectrogram_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.dims = dims
        self.spectrogram_conditioning_levels = spectrogram_conditioning_levels
        self.scale_steps = scale_steps

        self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps)
        #self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
        self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
        # For recording codebook usage.
        self.codes = torch.zeros((1228800,), dtype=torch.long)
        self.code_ind = 0
        self.internal_step = 0
        decoder_channels = [model_channels * channel_mult[s-1] for s in spectrogram_conditioning_levels]
        self.decoder = DiscreteDecoder(quantize_dim, decoder_channels[::-1], scale_steps)

        padding = 1 if kernel_size == 3 else 2

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.conditioning_enabled = conditioning_inputs_provided
        if conditioning_inputs_provided:
            self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
            self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
            self.embedding_combiner = EmbeddingCombiner(time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        self.convergence_convs = nn.ModuleList([])
        for level, mult in enumerate(channel_mult):
            if level in spectrogram_conditioning_levels:
                self.convergence_convs.append(conv_nd(dims, ch*2, ch, 1))

            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
        )

    def get_debug_values(self, step, __):
        # Note: this is poor design, but quantizer.get_temperature() does not just retrieve the
        # temperature - it also updates the annealing step, so it is important that this function
        # gets called regularly.
        return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)}
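
    # Hypothetical usage sketch (not from this file): a trainer would poll this
    # every step so the Gumbel temperature keeps annealing, e.g.
    #   logs = model.get_debug_values(step, None)
    #   tb_writer.add_histogram('codes', logs['histogram_codes'], step)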

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        img = self.norm(images)
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, commitment_loss, codes = self.codebook(logits)
        return codes

    def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
        if self.conditioning_enabled:
            assert conditioning_inputs is not None

        spec_hs = self.decoder(embeddings)[::-1]
        # Shape the spectrogram correctly. There is no guarantee it fits (though I should
        # probably add an assertion here to make sure the resizing isn't too wacky).
        spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
        convergence_fns = list(self.convergence_convs)

        # Timestep embeddings and conditioning signals are combined using a small transformer.
        hs = []
        emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.conditioning_enabled:
            mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1)  # +1 to account for the timestep embeddings we'll add.
            emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
            emb = self.embedding_combiner(emb, mask, self.query_gen(spec_hs[0]))
        else:
            emb = emb1

        # The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
        next_spec = spec_hs.pop(0)
        next_convergence_fn = convergence_fns.pop(0)
        h = x.type(self.dtype)
        for k, module in enumerate(self.input_blocks):
            h = module(h, emb)
            if next_spec is not None and h.shape[-1] == next_spec.shape[-1]:
                h = torch.cat([h, next_spec], dim=1)
                h = next_convergence_fn(h)
                if len(spec_hs) > 0:
                    next_spec = spec_hs.pop(0)
                    next_convergence_fn = convergence_fns.pop(0)
                else:
                    next_spec = None
            hs.append(h)
        assert len(spec_hs) == 0
        assert len(convergence_fns) == 0
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def decode(self, x, timesteps, codes, conditioning_inputs=None, num_conditioning_signals=None):
        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, hence this requirement.
        embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
        return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals)
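
    # Worked example of the assertion in decode(), under the default config (an
    # explanatory note, not in the original file): channel_mult has 7 entries,
    # so there are 6 Downsample stages of factor scale_steps=4 each; 4**6 == 4096,
    # hence the waveform length must be a multiple of 4096 for the U-net skip
    # connections to line up.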

    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
        # Compute DVAE portion first.
        spec_logits = self.encoder(spectrogram).permute((0,2,1))
        sampled, commitment_loss, codes = self.quantizer(spec_logits)

        if self.training:
            # Compute from softmax outputs to preserve gradients.
            embeddings = sampled.permute((0,2,1))
        else:
            # Compute from codes only.
            embeddings = self.quantizer.embed_code(codes).permute((0,2,1))

        # This is so we can debug the distribution of codes being learned.
        if self.internal_step % 50 == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
        self.internal_step += 1

        return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss


@register_model
def register_unet_diffusion_dvae(opt_net, opt):
    return DiffusionDVAE(**opt_net['kwargs'])


# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
    spec = torch.randn(4, 80, 160)
    ts = torch.LongTensor([432, 234, 100, 555])
    model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
                          channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
    print(model(torch.randn_like(spec), ts, spec)[0].shape)

@@ -1,179 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F

import torch.distributed as distributed

from models.vqvae.vqvae import ResBlock, Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)
        return self.conv(up)
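
# Aside (general background, not from this commit): interpolate-then-conv is the
# usual fix for the checkerboard artifacts ConvTranspose2d can introduce.
# Shape-wise it matches a stride-2 transposed conv, e.g.
#   UpsampleConv(64, 32, 5, padding=2)(torch.randn(1, 64, 16, 16)).shape
#   # -> torch.Size([1, 32, 32, 32])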


class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.LeakyReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.LeakyReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    UpsampleConv(channel, channel // 2, 5, padding=2),
                    nn.LeakyReLU(inplace=True),
                    UpsampleConv(
                        channel // 2, out_channel, 5, padding=2
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                UpsampleConv(channel, out_channel, 5, padding=2)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class VQVAE3(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        codebook_dim=64,
        codebook_size=512,
        decay=0.99,
    ):
        super().__init__()

        self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
                                            nn.LeakyReLU(inplace=True)])
        self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
        self.quantize_t = Quantize(codebook_dim, codebook_size)
        self.dec_t = Decoder(
            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
        )
        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
        self.quantize_b = Quantize(codebook_dim, codebook_size)
        self.upsample_t = UpsampleConv(
            codebook_dim, codebook_dim, 5, padding=2
        )
        self.dec = Decoder(
            codebook_dim + codebook_dim,
            32,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
        )
        self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        fea = self.initial_conv(input)
        enc_b = checkpoint(self.enc_b, fea)
        enc_t = checkpoint(self.enc_t, enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = checkpoint(self.dec, quant)
        dec = checkpoint(self.final_conv, dec)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec


@register_model
def register_vqvae3(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    return VQVAE3(**kw)


if __name__ == '__main__':
    v = VQVAE3()
    print(v(torch.randn(1,3,128,128))[0].shape)

@@ -1,279 +0,0 @@
import os
from time import time

import torch
import torchvision
import torch.distributed as distributed
from torch import nn
from tqdm import tqdm

from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
    convert_conv_net_state_dict_to_switched_conv
from models.vqvae.vqvae import ResBlock, Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size, padding, cfg):
        super().__init__()
        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth=cfg['breadth'], include_coupler=True,
                                            coupler_mode=cfg['mode'], coupler_dim_in=in_filters, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled'])

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)
        return self.conv(up)


class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, cfg):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.LeakyReLU(inplace=True),
                SwitchedConvHardRouting(channel // 2, channel, 5, breadth=cfg['breadth'], stride=2, include_coupler=True,
                                        coupler_mode=cfg['mode'], coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
                nn.LeakyReLU(inplace=True),
                SwitchedConvHardRouting(channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
                                        coupler_dim_in=channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.LeakyReLU(inplace=True),
                SwitchedConvHardRouting(channel // 2, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
                                        coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.LeakyReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, cfg
    ):
        super().__init__()

        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
                                          coupler_dim_in=in_channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled'])]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.LeakyReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    UpsampleConv(channel, channel // 2, 5, padding=2, cfg=cfg),
                    nn.LeakyReLU(inplace=True),
                    UpsampleConv(
                        channel // 2, out_channel, 5, padding=2, cfg=cfg
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                UpsampleConv(channel, out_channel, 5, padding=2, cfg=cfg)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class VQVAE3HardSwitch(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        codebook_dim=64,
        codebook_size=512,
        decay=0.99,
        cfg={'mode':'standard', 'breadth':16, 'hard_enabled': True, 'dropout': 0.4}
    ):
        super().__init__()

        self.cfg = cfg
        self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
                                            nn.LeakyReLU(inplace=True)])
        self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4, cfg=cfg)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, cfg=cfg)
        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
        self.quantize_t = Quantize(codebook_dim, codebook_size)
        self.dec_t = Decoder(
            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, cfg=cfg
        )
        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
        self.quantize_b = Quantize(codebook_dim, codebook_size)
        self.upsample_t = UpsampleConv(
            codebook_dim, codebook_dim, 5, padding=2, cfg=cfg
        )
        self.dec = Decoder(
            codebook_dim + codebook_dim,
            32,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
            cfg=cfg
        )
        self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
        from matplotlib import cm
        magnitude, indices = torch.topk(attention_out, 3, dim=1)
        indices = indices.cpu()
        colormap = cm.get_cmap(cmap_discrete_name, attention_size)
        img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy()))  # TODO: use other k's
        img = img.permute((0, 3, 1, 2))
        torchvision.utils.save_image(img, output_file)

    def visual_dbg(self, step, path):
        convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
        for i, c in enumerate(convs):
            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, 16)

    def get_debug_values(self, step, __):
        switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
                          ('enc_b_blk4', self.enc_b.blocks[4]),
                          ('enc_t_blk2', self.enc_t.blocks[2]),
                          ('dec_t_blk0', self.dec_t.blocks[0]),
                          ('dec_t_blk-1', self.dec_t.blocks[-1].conv),
                          ('dec_blk0', self.dec.blocks[0]),
                          ('dec_blk-1', self.dec.blocks[-1].conv),
                          ('dec_blk-3', self.dec.blocks[-3].conv)]
        logs = {}
        for name, swc in switched_convs:
            logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
        return logs

    def encode(self, input):
        fea = self.initial_conv(input)
        enc_b = checkpoint(self.enc_b, fea)
        enc_t = checkpoint(self.enc_t, enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = checkpoint(self.dec, quant)
        dec = checkpoint(self.final_conv, dec)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec


def convert_weights(weights_file):
    sd = torch.load(weights_file)
    from models.vqvae.vqvae_3 import VQVAE3
    std_model = VQVAE3()
    std_model.load_state_dict(sd)
    nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 16, ['quantize_conv_t', 'quantize_conv_b',
                                                                      'enc_b.blocks.0', 'enc_t.blocks.0',
                                                                      'conv.1', 'conv.3', 'initial_conv', 'final_conv'])
    torch.save(nsd, "converted.pth")


@register_model
def register_vqvae3_hard_switch(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    vq = VQVAE3HardSwitch(**kw)
    if distributed.is_initialized() and distributed.get_world_size() > 1:
        vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
    return vq


def performance_test():
    # For breadth=32:
    # Custom_cuda_naive: 15.4s
    # Torch_native: 29.2s
    #
    # For breadth=8:
    # Custom_cuda_naive: 9.8s
    # Torch_native: 10s
    cfg = {
        'mode': 'lambda',
        'breadth': 16,
        'hard_enabled': True,
        'dropout': 0,
    }
    net = VQVAE3HardSwitch(cfg=cfg).to('cuda').double()
    cfg['hard_enabled'] = False
    netO = VQVAE3HardSwitch(cfg=cfg).double()
    netO.load_state_dict(net.state_dict())
    netO = netO.cpu()

    loss = nn.L1Loss()
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    started = time()
    for j in tqdm(range(10)):
        inp = torch.rand((4, 3, 64, 64), device='cuda', dtype=torch.double)
        res = net(inp)[0]
        l = loss(res, inp)
        l.backward()

        res2 = netO(inp.cpu())[0]
        l = loss(res2, inp.cpu())
        l.backward()

        for p, op in zip(net.parameters(), netO.parameters()):
            diff = p.grad.cpu() - op.grad
            print(diff.max())

        opt.step()
        net.zero_grad()
    print("Elapsed: ", (time()-started))


if __name__ == '__main__':
    #v = VQVAE3HardSwitch()
    #print(v(torch.randn(1,3,128,128))[0].shape)
    #convert_weights("../../../experiments/vqvae_base.pth")
    performance_test()

@@ -1,149 +0,0 @@
import math
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

import torch.distributed as distributed

from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
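
# The buffer built in __init__ is the standard sinusoidal table from
# "Attention Is All You Need" (explanatory note, not from the original file):
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# so a fixed offset between positions becomes a linear transform of the
# encoding, which is what lets the attention layers below reason about position.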


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 layer_norm_eps=1e-5, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True,
                                               **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm1 = nn.BatchNorm1d(d_model)
        self.norm2 = nn.BatchNorm1d(d_model)

        self.activation = nn.ReLU()

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + src2
        src = self.norm1(src.permute(0,2,1)).permute(0,2,1)
        src2 = self.linear2(self.activation(self.linear1(src)))
        src = src + src2
        src = self.norm2(src.permute(0,2,1)).permute(0,2,1)
        return src


class Encoder(nn.Module):
    def __init__(self, in_channel, channel, output_breadth, num_layers=8, compression_factor=8):
        super().__init__()

        self.compression_factor = compression_factor
        self.pre_conv_stack = nn.Sequential(nn.Conv1d(in_channel, channel//4, kernel_size=3, padding=1),
                                            nn.ReLU(),
                                            nn.Conv1d(channel//4, channel//2, kernel_size=3, stride=2, padding=1),
                                            nn.ReLU(),
                                            nn.Conv1d(channel//2, channel//2, kernel_size=3, padding=1),
                                            nn.ReLU(),
                                            nn.Conv1d(channel//2, channel, kernel_size=3, stride=2, padding=1))
        self.norm1 = nn.BatchNorm1d(channel)
        self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth//4)
        self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)

    def forward(self, input):
        x = self.norm1(self.pre_conv_stack(input)).permute(0,2,1)
        x = self.positional_embeddings(x)
        x = self.encode(x)
        return x[:,:input.shape[2]//self.compression_factor,:]


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, output_breadth, num_layers=6
    ):
        super().__init__()

        self.initial_conv = nn.Conv1d(in_channel, channel, kernel_size=1)
        self.expand = output_breadth
        self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth)
        self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)
        self.final_conv_stack = nn.Sequential(nn.Conv1d(channel, channel, kernel_size=3, padding=1),
                                              nn.ReLU(),
                                              nn.Conv1d(channel, out_channel, kernel_size=3, padding=1))

    def forward(self, input):
        x = self.initial_conv(input.permute(0,2,1)).permute(0,2,1)
        x = nn.functional.pad(x, (0,0,0, self.expand-input.shape[1]))
        x = self.positional_embeddings(x)
        x = self.encode(x).permute(0,2,1)
        return self.final_conv_stack(x)


class VQVAE(nn.Module):
    def __init__(
        self,
        data_channels=1,
        channel=256,
        codebook_dim=256,
        codebook_size=512,
        breadth=80,
    ):
        super().__init__()

        self.enc = Encoder(data_channels, channel, breadth)
        self.quantize_dense = nn.Linear(channel, codebook_dim)
        self.quantize = Quantize(codebook_dim, codebook_size)
        self.dec = Decoder(codebook_dim, data_channels, channel, breadth)

    def forward(self, input):
        input = input.unsqueeze(1)
        quant, diff, _ = self.encode(input)
        dec = checkpoint(self.dec, quant)
        dec = dec.squeeze(1)
        return dec, diff

    def encode(self, input):
        enc = checkpoint(self.enc, input)
        quant = self.quantize_dense(enc)
        quant, diff, id = self.quantize(quant)
        diff = diff.unsqueeze(0)
        return quant, diff, id


@register_model
def register_vqvae_xform_audio(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    vq = VQVAE(**kw)
    return vq


if __name__ == '__main__':
    model = VQVAE()
    res = model(torch.randn(4,80))
    print(res[0].shape)

@@ -1,265 +0,0 @@
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
# a "standard" upsampler rather than ConvTranspose.


import torch
from torch import nn
from torch.nn import functional as F

import torch.distributed as distributed

from trainer.networks import register_model
from utils.util import checkpoint, opt_get


# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)
        return self.conv(up)


class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            if distributed.is_initialized() and distributed.get_world_size() > 1:
                distributed.all_reduce(embed_onehot_sum)
                distributed.all_reduce(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))
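
    # The EMA update in forward(), written out (explanatory note, not from the
    # original file): with decay d, one-hot assignments o_b and inputs z_b,
    #   N_i <- d * N_i + (1 - d) * sum_b o_{b,i}         (cluster_size)
    #   m_i <- d * m_i + (1 - d) * sum_b o_{b,i} * z_b   (embed_avg)
    #   e_i  = m_i / smooth(N_i)
    # where smooth() applies the same Laplace smoothing as laplace_smoothing()
    # in arch_util, keeping rarely-used codes from dividing by zero.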


class ResBlock(nn.Module):
    def __init__(self, in_channel, channel):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out


class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    UpsampleConv(channel, channel // 2, 5, padding=2),
                    nn.ReLU(inplace=True),
                    UpsampleConv(
                        channel // 2, out_channel, 5, padding=2
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                UpsampleConv(channel, out_channel, 5, padding=2)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        codebook_dim=64,
        codebook_size=512,
        decay=0.99,
    ):
        super().__init__()

        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
        self.quantize_t = Quantize(codebook_dim, codebook_size)
        self.dec_t = Decoder(
            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
        )
        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
        self.quantize_b = Quantize(codebook_dim, codebook_size*2)
        self.upsample_t = UpsampleConv(
            codebook_dim, codebook_dim, 5, padding=2
        )
        self.dec = Decoder(
            codebook_dim + codebook_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
        )

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        enc_b = checkpoint(self.enc_b, input)
        enc_t = checkpoint(self.enc_t, enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = checkpoint(self.dec, quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec


@register_model
def register_vqvae_normalized(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    return VQVAE(**kw)


if __name__ == '__main__':
    v = VQVAE()
    print(v(torch.randn(1,3,128,128))[0].shape)
@ -1,293 +0,0 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import torch.distributed as distributed
|
||||
|
||||
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
||||
convert_conv_net_state_dict_to_switched_conv
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
|
||||
class UpsampleConv(nn.Module):
|
||||
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
|
||||
super().__init__()
|
||||
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
|
||||
|
||||
def forward(self, x):
|
||||
up = torch.nn.functional.interpolate(x, scale_factor=2)
|
||||
return self.conv(up)
|
||||
|
||||
|
||||
class Quantize(nn.Module):
|
||||
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.n_embed = n_embed
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
|
||||
embed = torch.randn(dim, n_embed)
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
def forward(self, input):
|
||||
flatten = input.reshape(-1, self.dim)
|
||||
dist = (
|
||||
flatten.pow(2).sum(1, keepdim=True)
|
||||
- 2 * flatten @ self.embed
|
||||
+ self.embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
_, embed_ind = (-dist).max(1)
|
||||
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
|
||||
embed_ind = embed_ind.view(*input.shape[:-1])
|
||||
quantize = self.embed_code(embed_ind)
|
||||
|
||||
if self.training:
|
||||
embed_onehot_sum = embed_onehot.sum(0)
|
||||
embed_sum = flatten.transpose(0, 1) @ embed_onehot
|
||||
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
distributed.all_reduce(embed_onehot_sum)
|
||||
distributed.all_reduce(embed_sum)
|
||||
|
||||
self.cluster_size.data.mul_(self.decay).add_(
|
||||
embed_onehot_sum, alpha=1 - self.decay
|
||||
)
|
||||
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
||||
n = self.cluster_size.sum()
|
||||
cluster_size = (
|
||||
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
diff = (quantize.detach() - input).pow(2).mean()
|
||||
quantize = input + (quantize - input).detach()
|
||||
|
||||
return quantize, diff, embed_ind
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
return F.embedding(embed_id, self.embed.transpose(0, 1))
|
||||
|
||||
|
||||

class ResBlock(nn.Module):
    def __init__(self, in_channel, channel, breadth):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out

class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
                nn.ReLU(inplace=True),
                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                nn.ReLU(inplace=True),
                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, breadth))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)
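
A shape sketch for the stride-4 path: two stride-2 stages halve the spatial resolution twice, so a 128x128 input comes out 4x smaller (this assumes SwitchedConvHardRouting pads so each stride-2 stage exactly halves, which the decode path of the VQVAE below requires):

enc = Encoder(3, 128, n_res_block=2, n_res_channel=32, stride=4, breadth=8)
print(enc(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 128, 32, 32])
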

class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth
    ):
        super().__init__()

        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, breadth))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    UpsampleConv(channel, channel // 2, breadth, 5, padding=2),
                    nn.ReLU(inplace=True),
                    UpsampleConv(
                        channel // 2, out_channel, breadth, 5, padding=2
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                UpsampleConv(channel, out_channel, breadth, 5, padding=2)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)

class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        codebook_dim=64,
        codebook_size=512,
        decay=0.99,
        breadth=8,
    ):
        super().__init__()

        self.breadth = breadth
        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
        self.quantize_t = Quantize(codebook_dim, codebook_size)
        self.dec_t = Decoder(
            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth
        )
        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
        self.quantize_b = Quantize(codebook_dim, codebook_size*2)
        self.upsample_t = UpsampleConv(
            codebook_dim, codebook_dim, breadth, 5, padding=2
        )
        self.dec = Decoder(
            codebook_dim + codebook_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
            breadth=breadth
        )

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
        from matplotlib import cm
        magnitude, indices = torch.topk(attention_out, 3, dim=1)
        indices = indices.cpu()
        colormap = cm.get_cmap(cmap_discrete_name, attention_size)
        img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy()))  # TODO: use other k's
        img = img.permute((0, 3, 1, 2))
        torchvision.utils.save_image(img, output_file)

    def visual_dbg(self, step, path):
        convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
        for i, c in enumerate(convs):
            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)

    def get_debug_values(self, step, __):
        switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
                          ('enc_b_blk4', self.enc_b.blocks[4]),
                          ('enc_t_blk2', self.enc_t.blocks[2]),
                          ('dec_t_blk0', self.dec_t.blocks[0]),
                          ('dec_t_blk-1', self.dec_t.blocks[-1].conv),
                          ('dec_blk0', self.dec.blocks[0]),
                          ('dec_blk-1', self.dec.blocks[-1].conv),
                          ('dec_blk-3', self.dec.blocks[-3].conv)]
        logs = {}
        for name, swc in switched_convs:
            logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
        return logs

    def encode(self, input):
        enc_b = checkpoint(self.enc_b, input)
        enc_t = checkpoint(self.enc_t, enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = checkpoint(self.dec, quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec
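
A usage sketch for the full two-level model: shapes follow from the stride-4 bottom and stride-2 top encoders, and the commitment weight in the last line is an assumption borrowed from the VQ-VAE literature, not taken from this repo's trainer.

v = VQVAE()
img = torch.randn(1, 3, 128, 128)
quant_t, quant_b, diff, id_t, id_b = v.encode(img)
print(id_t.shape, id_b.shape)        # (1, 16, 16) top codes, (1, 32, 32) bottom codes
recon = v.decode_code(id_t, id_b)    # decode directly from the discrete codes
print(recon.shape)                   # (1, 3, 128, 128)
recon2, diff2 = v(img)
loss = F.mse_loss(recon2, img) + 0.25 * diff2.mean()  # hypothetical training objective
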

def convert_weights(weights_file):
    sd = torch.load(weights_file)
    import models.vqvae.vqvae_no_conv_transpose as stdvq
    std_model = stdvq.VQVAE()
    std_model.load_state_dict(sd)
    nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b',
                                                                      'enc_b.blocks.0', 'enc_t.blocks.0',
                                                                      'conv.1', 'conv.3'])
    torch.save(nsd, "converted.pth")

@register_model
def register_vqvae_norm_hard_switched_conv_lambda(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    return VQVAE(**kw)


if __name__ == '__main__':
    v = VQVAE(breadth=8).cuda()
    print(v(torch.randn(1,3,128,128).cuda())[0].shape)
    #convert_weights("../../../experiments/50000_generator.pth")
@@ -1,276 +0,0 @@
import os

import torch
import torchvision
from torch import nn
from torch.nn import functional as F

import torch.distributed as distributed

from models.switched_conv.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
        super().__init__()
        self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)
        return self.conv(up)
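
Aside from the import swap, this second deleted file differs from the hard-routed variant above in three visible ways: every switched layer is a SwitchedConv with coupler_mode='lambda' (including inside ResBlock and the Encoder stem, which used plain nn.Conv2d above), breadth defaults to 4 rather than 8, and its convert_weights below converts only the two quantize_conv layers.
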

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            if distributed.is_initialized() and distributed.get_world_size() > 1:
                distributed.all_reduce(embed_onehot_sum)
                distributed.all_reduce(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

class ResBlock(nn.Module):
    def __init__(self, in_channel, channel, breadth):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),
            SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
            nn.ReLU(inplace=True),
            SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out

class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth):
        super().__init__()

        if stride == 4:
            blocks = [
                SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
                nn.ReLU(inplace=True),
                SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
                nn.ReLU(inplace=True),
                SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
            ]

        elif stride == 2:
            blocks = [
                SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
                nn.ReLU(inplace=True),
                SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, breadth))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)

class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth
    ):
        super().__init__()

        blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, breadth))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    UpsampleConv(channel, channel // 2, breadth, 5, padding=2),
                    nn.ReLU(inplace=True),
                    UpsampleConv(
                        channel // 2, out_channel, breadth, 5, padding=2
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                UpsampleConv(channel, out_channel, breadth, 5, padding=2)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)

class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        codebook_dim=64,
        codebook_size=512,
        decay=0.99,
        breadth=4,
    ):
        super().__init__()

        self.breadth = breadth
        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
        self.quantize_t = Quantize(codebook_dim, codebook_size)
        self.dec_t = Decoder(
            codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth
        )
        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
        self.quantize_b = Quantize(codebook_dim, codebook_size*2)
        self.upsample_t = UpsampleConv(
            codebook_dim, codebook_dim, breadth, 5, padding=2
        )
        self.dec = Decoder(
            codebook_dim + codebook_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
            breadth=breadth
        )

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
        from matplotlib import cm
        magnitude, indices = torch.topk(attention_out, 3, dim=1)
        indices = indices.cpu()
        colormap = cm.get_cmap(cmap_discrete_name, attention_size)
        img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy()))  # TODO: use other k's
        img = img.permute((0, 3, 1, 2))
        torchvision.utils.save_image(img, output_file)

    def visual_dbg(self, step, path):
        convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
        for i, c in enumerate(convs):
            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)

    def encode(self, input):
        enc_b = checkpoint(self.enc_b, input)
        enc_t = checkpoint(self.enc_t, enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = checkpoint(self.dec, quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec

def convert_weights(weights_file):
    sd = torch.load(weights_file)
    import models.vqvae.vqvae_no_conv_transpose as stdvq
    std_model = stdvq.VQVAE()
    std_model.load_state_dict(sd)
    nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b'])
    torch.save(nsd, "converted.pth")

@register_model
def register_vqvae_norm_switched_conv_lambda(opt_net, opt):
    kw = opt_get(opt_net, ['kwargs'], {})
    return VQVAE(**kw)


if __name__ == '__main__':
    #v = VQVAE()
    #print(v(torch.randn(1,3,128,128))[0].shape)
    convert_weights("../../../experiments/4000_generator.pth")
@@ -55,7 +55,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-17.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-20.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -284,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_audio_clips_with_quantizer_compression.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()