Clean stuff up, move more things into arch_util

James Betker 2021-10-20 21:19:25 -06:00
parent a6f0f854b9
commit f2a31702b5
11 changed files with 37 additions and 2018 deletions


@@ -8,6 +8,38 @@ import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def l2norm(t):
return F.normalize(t, p = 2, dim = -1)
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
def laplace_smoothing(x, n_categories, eps = 1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device = device)
return samples[indices]
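A minimal usage sketch of the helpers added above, assuming the definitions are in scope (the sizes are illustrative):
import torch
flat = torch.randn(4096, 64)                           # flattened encoder outputs (N, dim)
init_codes = default(None, sample_vectors(flat, 512))  # falls back to 512 randomly sampled vectors
cluster_size = torch.zeros(512)
ema_inplace(cluster_size, torch.ones(512), 0.99)       # EMA update of per-code usage counts
probs = laplace_smoothing(cluster_size, 512)           # smoothed distribution; sums to 1
unit = l2norm(flat)                                    # rows normalized to unit L2 norm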
def kaiming_init(module,
a=0,
mode='fan_out',
@@ -24,9 +56,11 @@ def kaiming_init(module,
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def pixel_norm(x, epsilon=1e-8):
return x * torch.rsqrt(torch.mean(torch.pow(x, 2), dim=1, keepdims=True) + epsilon)
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
@@ -75,20 +109,12 @@ def default_init_weights(module, scale=1):
elif isinstance(m, nn.Linear):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
"""
Various utilities for neural networks.
"""
import math
import torch as th
import torch.nn as nn
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)
return x * torch.sigmoid(x)
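For reference, on PyTorch 1.7+ this shim is equivalent to the built-in nn.SiLU; a quick standalone sanity check:
import torch
import torch.nn as nn
x = torch.randn(2, 8)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))  # built-in SiLU == x * sigmoid(x)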
class GroupNorm32(nn.GroupNorm):


@@ -1,387 +0,0 @@
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.respace import SpacedDiffusion, space_timesteps
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
Downsample, Upsample
import torch
import torch.nn as nn
from models.gpt_voice.lucidrains_dvae import eval_decorator
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.gumbel_quantizer import GumbelQuantizer
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import get_mask_from_lengths
import models.gpt_voice.mini_encoder as menc
class DiscreteEncoder(nn.Module):
def __init__(self,
in_channels,
model_channels,
out_channels,
dropout,
scale):
super().__init__()
self.blocks = nn.Sequential(
conv_nd(1, in_channels, model_channels, 3, padding=1),
menc.ResBlock(model_channels, dropout, dims=1),
Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale),
menc.ResBlock(model_channels*2, dropout, dims=1),
Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale),
menc.ResBlock(model_channels*4, dropout, dims=1),
AttentionBlock(model_channels*4, num_heads=4),
menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
)
def forward(self, spectrogram):
return self.blocks(spectrogram)
class DiscreteDecoder(nn.Module):
def __init__(self, in_channels, level_channels, scale):
super().__init__()
# Just raw upsampling; returns a list with each layer's output.
self.init = conv_nd(1, in_channels, level_channels[0], kernel_size=3, padding=1)
layers = []
for i, lvl in enumerate(level_channels[:-1]):
layers.append(nn.Sequential(normalization(lvl),
nn.SiLU(lvl),
Upsample(lvl, use_conv=True, dims=1, out_channels=level_channels[i+1], factor=scale)))
self.layers = nn.ModuleList(layers)
def forward(self, x):
y = self.init(x)
outs = [y]
for layer in self.layers:
y = layer(y)
outs.append(y)
return outs
class DiffusionDVAE(nn.Module):
def __init__(
self,
model_channels,
num_res_blocks,
in_channels=1,
out_channels=2, # mean and variance
spectrogram_channels=80,
spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform.
dropout=0,
channel_mult=(1, 2, 4, 8, 16, 32, 64),
attention_resolutions=(16,32,64),
conv_resample=True,
dims=1,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
use_new_attention_order=False,
kernel_size=5,
quantize_dim=1024,
num_discrete_codes=8192,
scale_steps=4,
conditioning_inputs_provided=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.spectrogram_channels = spectrogram_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.dims = dims
self.spectrogram_conditioning_levels = spectrogram_conditioning_levels
self.scale_steps = scale_steps
self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps)
#self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
# For recording codebook usage.
self.codes = torch.zeros((1228800,), dtype=torch.long)
self.code_ind = 0
self.internal_step = 0
decoder_channels = [model_channels * channel_mult[s-1] for s in spectrogram_conditioning_levels]
self.decoder = DiscreteDecoder(quantize_dim, decoder_channels[::-1], scale_steps)
padding = 1 if kernel_size == 3 else 2
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.conditioning_enabled = conditioning_inputs_provided
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
self.convergence_convs = nn.ModuleList([])
for level, mult in enumerate(channel_mult):
if level in spectrogram_conditioning_levels:
self.convergence_convs.append(conv_nd(dims, ch*2, ch, 1))
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
)
def get_debug_values(self, step, __):
# Note: this is very poor design, but quantizer.get_temperature() does not just retrieve the temperature; it also advances the step, so it is extremely important that this function be called regularly.
return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)}
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
return codes
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
if self.conditioning_enabled:
assert conditioning_inputs is not None
spec_hs = self.decoder(embeddings)[::-1]
# Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
convergence_fns = list(self.convergence_convs)
# Timestep embeddings and conditioning signals are combined using a small transformer.
hs = []
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.conditioning_enabled:
mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add.
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
emb = self.embedding_combiner(emb, mask, self.query_gen(spec_hs[0]))
else:
emb = emb1
# The rest is the diffusion vocoder, built as a standard U-net. spec_hs is gradually fed into the encoder.
next_spec = spec_hs.pop(0)
next_convergence_fn = convergence_fns.pop(0)
h = x.type(self.dtype)
for k, module in enumerate(self.input_blocks):
h = module(h, emb)
if next_spec is not None and h.shape[-1] == next_spec.shape[-1]:
h = torch.cat([h, next_spec], dim=1)
h = next_convergence_fn(h)
if len(spec_hs) > 0:
next_spec = spec_hs.pop(0)
next_convergence_fn = convergence_fns.pop(0)
else:
next_spec = None
hs.append(h)
assert len(spec_hs) == 0
assert len(convergence_fns) == 0
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h)
def decode(self, x, timesteps, codes, conditioning_inputs=None, num_conditioning_signals=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, hence this requirement.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals)
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
# Compute DVAE portion first.
spec_logits = self.encoder(spectrogram).permute((0,2,1))
sampled, commitment_loss, codes = self.quantizer(spec_logits)
if self.training:
# Compute from softmax outputs to preserve gradients.
embeddings = sampled.permute((0,2,1))
else:
# Compute from codes only.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
# This is so we can debug the distribution of codes being learned.
if self.internal_step % 50 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.internal_step += 1
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss
@register_model
def register_unet_diffusion_dvae(opt_net, opt):
return DiffusionDVAE(**opt_net['kwargs'])
'''
class DiffusionDVAE(nn.Module):
def __init__(
self,
model_channels,
num_res_blocks,
in_channels=1,
out_channels=2, # mean and variance
spectrogram_channels=80,
spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform.
dropout=0,
channel_mult=(1, 2, 4, 8, 16, 32, 64),
attention_resolutions=(16,32,64),
conv_resample=True,
dims=1,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
use_new_attention_order=False,
kernel_size=5,
quantize_dim=1024,
num_discrete_codes=8192,
scale_steps=4,
conditioning_inputs_provided=True,
):
'''
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
spec = torch.randn(4, 80, 160)
ts = torch.LongTensor([432, 234, 100, 555])
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
print(model(torch.randn_like(spec), ts, spec)[0].shape)


@@ -1,179 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributed as distributed
from models.vqvae.vqvae import ResBlock, Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, padding):
super().__init__()
self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, 5, padding=2),
nn.LeakyReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, 5, padding=2
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, 5, padding=2)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE3(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
):
super().__init__()
self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
nn.LeakyReLU(inplace=True)])
self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 5, padding=2
)
self.dec = Decoder(
codebook_dim + codebook_dim,
32,
channel,
n_res_block,
n_res_channel,
stride=4,
)
self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def encode(self, input):
fea = self.initial_conv(input)
enc_b = checkpoint(self.enc_b, fea)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
dec = checkpoint(self.final_conv, dec)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
@register_model
def register_vqvae3(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE3(**kw)
if __name__ == '__main__':
v = VQVAE3()
print(v(torch.randn(1,3,128,128))[0].shape)


@@ -1,279 +0,0 @@
import os
from time import time
import torch
import torchvision
import torch.distributed as distributed
from torch import nn
from tqdm import tqdm
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
convert_conv_net_state_dict_to_switched_conv
from models.vqvae.vqvae import ResBlock, Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, padding, cfg):
super().__init__()
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth=cfg['breadth'], include_coupler=True,
coupler_mode=cfg['mode'], coupler_dim_in=in_filters, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled'])
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, cfg):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 5, breadth=cfg['breadth'], stride=2, include_coupler=True,
coupler_mode=cfg['mode'], coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
nn.LeakyReLU(inplace=True),
SwitchedConvHardRouting(channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
coupler_dim_in=channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, cfg
):
super().__init__()
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'],
coupler_dim_in=in_channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled'])]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, 5, padding=2, cfg=cfg),
nn.LeakyReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, 5, padding=2, cfg=cfg
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, 5, padding=2, cfg=cfg)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE3HardSwitch(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
cfg={'mode':'standard', 'breadth':16, 'hard_enabled': True, 'dropout': 0.4}
):
super().__init__()
self.cfg = cfg
self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
nn.LeakyReLU(inplace=True)])
self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4, cfg=cfg)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, cfg=cfg)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, cfg=cfg
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 5, padding=2, cfg=cfg
)
self.dec = Decoder(
codebook_dim + codebook_dim,
32,
channel,
n_res_block,
n_res_channel,
stride=4,
cfg=cfg
)
self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
from matplotlib import cm
magnitude, indices = torch.topk(attention_out, 3, dim=1)
indices = indices.cpu()
colormap = cm.get_cmap(cmap_discrete_name, attention_size)
img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's
img = img.permute((0, 3, 1, 2))
torchvision.utils.save_image(img, output_file)
def visual_dbg(self, step, path):
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
for i, c in enumerate(convs):
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, 16)
def get_debug_values(self, step, __):
switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
('enc_b_blk4', self.enc_b.blocks[4]),
('enc_t_blk2', self.enc_t.blocks[2]),
('dec_t_blk0', self.dec_t.blocks[0]),
('dec_t_blk-1', self.dec_t.blocks[-1].conv),
('dec_blk0', self.dec.blocks[0]),
('dec_blk-1', self.dec.blocks[-1].conv),
('dec_blk-3', self.dec.blocks[-3].conv)]
logs = {}
for name, swc in switched_convs:
logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
return logs
def encode(self, input):
fea = self.initial_conv(input)
enc_b = checkpoint(self.enc_b, fea)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
dec = checkpoint(self.final_conv, dec)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
def convert_weights(weights_file):
sd = torch.load(weights_file)
from models.vqvae.vqvae_3 import VQVAE3
std_model = VQVAE3()
std_model.load_state_dict(sd)
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 16, ['quantize_conv_t', 'quantize_conv_b',
'enc_b.blocks.0', 'enc_t.blocks.0',
'conv.1', 'conv.3', 'initial_conv', 'final_conv'])
torch.save(nsd, "converted.pth")
@register_model
def register_vqvae3_hard_switch(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
vq = VQVAE3HardSwitch(**kw)
if distributed.is_initialized() and distributed.get_world_size() > 1:
vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq)
return vq
def performance_test():
# For breadth=32:
# Custom_cuda_naive: 15.4
# Torch_native: 29.2s
#
# For breadth=8
# Custom_cuda_naive: 9.8
# Torch_native: 10s
cfg = {
'mode': 'lambda',
'breadth': 16,
'hard_enabled': True,
'dropout': 0,
}
net = VQVAE3HardSwitch(cfg=cfg).to('cuda').double()
cfg['hard_enabled'] = False
netO = VQVAE3HardSwitch(cfg=cfg).double()
netO.load_state_dict(net.state_dict())
netO = netO.cpu()
loss = nn.L1Loss()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
started = time()
for j in tqdm(range(10)):
inp = torch.rand((4, 3, 64, 64), device='cuda', dtype=torch.double)
res = net(inp)[0]
l = loss(res, inp)
l.backward()
res2 = netO(inp.cpu())[0]
l = loss(res2, inp.cpu())
l.backward()
for p, op in zip(net.parameters(), netO.parameters()):
diff = p.grad.cpu() - op.grad
print(diff.max())
opt.step()
net.zero_grad()
print("Elapsed: ", (time()-started))
if __name__ == '__main__':
#v = VQVAE3HardSwitch()
#print(v(torch.randn(1,3,128,128))[0].shape)
#convert_weights("../../../experiments/vqvae_base.pth")
performance_test()


@@ -1,179 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributed as distributed
from models.vqvae.vqvae import ResBlock, Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, padding):
super().__init__()
self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.LeakyReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.LeakyReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, 5, padding=2),
nn.LeakyReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, 5, padding=2
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, 5, padding=2)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE3(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
):
super().__init__()
self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1),
nn.LeakyReLU(inplace=True)])
self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 5, padding=2
)
self.dec = Decoder(
codebook_dim + codebook_dim,
32,
channel,
n_res_block,
n_res_channel,
stride=4,
)
self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def encode(self, input):
fea = self.initial_conv(input)
enc_b = checkpoint(self.enc_b, fea)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
dec = checkpoint(self.final_conv, dec)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
@register_model
def register_vqvae3(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE3(**kw)
if __name__ == '__main__':
v = VQVAE3()
print(v(torch.randn(1,3,128,128))[0].shape)


@@ -1,149 +0,0 @@
import math
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import torch.distributed as distributed
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
layer_norm_eps=1e-5, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True,
**factory_kwargs)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.norm1 = nn.BatchNorm1d(d_model)
self.norm2 = nn.BatchNorm1d(d_model)
self.activation = nn.ReLU()
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + src2
src = self.norm1(src.permute(0,2,1)).permute(0,2,1)
src2 = self.linear2(self.activation(self.linear1(src)))
src = src + src2
src = self.norm2(src.permute(0,2,1)).permute(0,2,1)
return src
class Encoder(nn.Module):
def __init__(self, in_channel, channel, output_breadth, num_layers=8, compression_factor=8):
super().__init__()
self.compression_factor = compression_factor
self.pre_conv_stack = nn.Sequential(nn.Conv1d(in_channel, channel//4, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(channel//4, channel//2, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv1d(channel//2, channel//2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(channel//2, channel, kernel_size=3, stride=2, padding=1))
self.norm1 = nn.BatchNorm1d(channel)
self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth//4)
self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)
def forward(self, input):
x = self.norm1(self.pre_conv_stack(input)).permute(0,2,1)
x = self.positional_embeddings(x)
x = self.encode(x)
return x[:,:input.shape[2]//self.compression_factor,:]
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, output_breadth, num_layers=6
):
super().__init__()
self.initial_conv = nn.Conv1d(in_channel, channel, kernel_size=1)
self.expand = output_breadth
self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth)
self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers)
self.final_conv_stack = nn.Sequential(nn.Conv1d(channel, channel, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(channel, out_channel, kernel_size=3, padding=1))
def forward(self, input):
x = self.initial_conv(input.permute(0,2,1)).permute(0,2,1)
x = nn.functional.pad(x, (0,0,0, self.expand-input.shape[1]))
x = self.positional_embeddings(x)
x = self.encode(x).permute(0,2,1)
return self.final_conv_stack(x)
class VQVAE(nn.Module):
def __init__(
self,
data_channels=1,
channel=256,
codebook_dim=256,
codebook_size=512,
breadth=80,
):
super().__init__()
self.enc = Encoder(data_channels, channel, breadth)
self.quantize_dense = nn.Linear(channel, codebook_dim)
self.quantize = Quantize(codebook_dim, codebook_size)
self.dec = Decoder(codebook_dim, data_channels, channel, breadth)
def forward(self, input):
input = input.unsqueeze(1)
quant, diff, _ = self.encode(input)
dec = checkpoint(self.dec, quant)
dec = dec.squeeze(1)
return dec, diff
def encode(self, input):
enc = checkpoint(self.enc, input)
quant = self.quantize_dense(enc)
quant, diff, id = self.quantize(quant)
diff = diff.unsqueeze(0)
return quant, diff, id
@register_model
def register_vqvae_xform_audio(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
vq = VQVAE(**kw)
return vq
if __name__ == '__main__':
model = VQVAE()
res=model(torch.randn(4,80))
print(res[0].shape)


@@ -1,265 +0,0 @@
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
# a "standard" upsampler rather than ConvTranspose.
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributed as distributed
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, padding):
super().__init__()
self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input):
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(embed_onehot_sum)
distributed.all_reduce(embed_sum)
self.cluster_size.data.mul_(self.decay).add_(
embed_onehot_sum, alpha=1 - self.decay
)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
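A minimal usage sketch of this Quantize layer, assuming the class above is in scope (shapes mirror how VQVAE.encode calls it, with channels moved to the last dimension; the sizes are illustrative):
import torch
quantizer = Quantize(dim=64, n_embed=512)
feats = torch.randn(2, 16, 16, 64)               # (B, H, W, C): channel-last features
quantized, commitment, codes = quantizer(feats)  # straight-through estimator keeps gradients flowing
assert quantized.shape == feats.shape
assert codes.shape == (2, 16, 16)                # one codebook index per spatial position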
class ResBlock(nn.Module):
def __init__(self, in_channel, channel):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1),
)
def forward(self, input):
out = self.conv(input)
out += input
return out
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, 5, padding=2),
nn.ReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, 5, padding=2
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, 5, padding=2)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
):
super().__init__()
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 5, padding=2
)
self.dec = Decoder(
codebook_dim + codebook_dim,
in_channel,
channel,
n_res_block,
n_res_channel,
stride=4,
)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def encode(self, input):
enc_b = checkpoint(self.enc_b, input)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
@register_model
def register_vqvae_normalized(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE(**kw)
if __name__ == '__main__':
v = VQVAE()
print(v(torch.randn(1,3,128,128))[0].shape)


@@ -1,293 +0,0 @@
import os
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import torch.distributed as distributed
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
convert_conv_net_state_dict_to_switched_conv
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
super().__init__()
self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input):
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(embed_onehot_sum)
distributed.all_reduce(embed_sum)
self.cluster_size.data.mul_(self.decay).add_(
embed_onehot_sum, alpha=1 - self.decay
)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
class ResBlock(nn.Module):
def __init__(self, in_channel, channel, breadth):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1),
)
def forward(self, input):
out = self.conv(input)
out += input
return out
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True),
SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel, breadth))
blocks.append(nn.ReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth
):
super().__init__()
blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel, breadth))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, breadth, 5, padding=2),
nn.ReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, breadth, 5, padding=2
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, breadth, 5, padding=2)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
breadth=8,
):
super().__init__()
self.breadth = breadth
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, breadth, 5, padding=2
)
self.dec = Decoder(
codebook_dim + codebook_dim,
in_channel,
channel,
n_res_block,
n_res_channel,
stride=4,
breadth=breadth
)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
from matplotlib import cm
magnitude, indices = torch.topk(attention_out, 3, dim=1)
indices = indices.cpu()
colormap = cm.get_cmap(cmap_discrete_name, attention_size)
img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's
img = img.permute((0, 3, 1, 2))
torchvision.utils.save_image(img, output_file)
def visual_dbg(self, step, path):
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
for i, c in enumerate(convs):
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
def get_debug_values(self, step, __):
switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
('enc_b_blk4', self.enc_b.blocks[4]),
('enc_t_blk2', self.enc_t.blocks[2]),
('dec_t_blk0', self.dec_t.blocks[0]),
('dec_t_blk-1', self.dec_t.blocks[-1].conv),
('dec_blk0', self.dec.blocks[0]),
('dec_blk-1', self.dec.blocks[-1].conv),
('dec_blk-3', self.dec.blocks[-3].conv)]
logs = {}
for name, swc in switched_convs:
logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
return logs
def encode(self, input):
enc_b = checkpoint(self.enc_b, input)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
def convert_weights(weights_file):
sd = torch.load(weights_file)
import models.vqvae.vqvae_no_conv_transpose as stdvq
std_model = stdvq.VQVAE()
std_model.load_state_dict(sd)
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b',
'enc_b.blocks.0', 'enc_t.blocks.0',
'conv.1', 'conv.3'])
torch.save(nsd, "converted.pth")
@register_model
def register_vqvae_norm_hard_switched_conv_lambda(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE(**kw)
if __name__ == '__main__':
v = VQVAE(breadth=8).cuda()
print(v(torch.randn(1,3,128,128).cuda())[0].shape)
#convert_weights("../../../experiments/50000_generator.pth")


@@ -1,276 +0,0 @@
import os
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import torch.distributed as distributed
from models.switched_conv.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
class UpsampleConv(nn.Module):
def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
super().__init__()
self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters)
def forward(self, x):
up = torch.nn.functional.interpolate(x, scale_factor=2)
return self.conv(up)
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input):
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(embed_onehot_sum)
distributed.all_reduce(embed_sum)
self.cluster_size.data.mul_(self.decay).add_(
embed_onehot_sum, alpha=1 - self.decay
)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
class ResBlock(nn.Module):
def __init__(self, in_channel, channel, breadth):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
nn.ReLU(inplace=True),
SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
)
def forward(self, input):
out = self.conv(input)
out += input
return out
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth):
super().__init__()
if stride == 4:
blocks = [
SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
nn.ReLU(inplace=True),
SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
nn.ReLU(inplace=True),
SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
]
elif stride == 2:
blocks = [
SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
nn.ReLU(inplace=True),
SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel, breadth))
blocks.append(nn.ReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth
):
super().__init__()
blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel, breadth))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
UpsampleConv(channel, channel // 2, breadth, 5, padding=2),
nn.ReLU(inplace=True),
UpsampleConv(
channel // 2, out_channel, breadth, 5, padding=2
),
]
)
elif stride == 2:
blocks.append(
UpsampleConv(channel, out_channel, breadth, 5, padding=2)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE(nn.Module):
def __init__(
self,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
codebook_dim=64,
codebook_size=512,
decay=0.99,
breadth=4,
):
super().__init__()
self.breadth = breadth
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
self.quantize_t = Quantize(codebook_dim, codebook_size)
self.dec_t = Decoder(
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, breadth, 5, padding=2
)
self.dec = Decoder(
codebook_dim + codebook_dim,
in_channel,
channel,
n_res_block,
n_res_channel,
stride=4,
breadth=breadth
)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
from matplotlib import cm
magnitude, indices = torch.topk(attention_out, 3, dim=1)
indices = indices.cpu()
colormap = cm.get_cmap(cmap_discrete_name, attention_size)
img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's
img = img.permute((0, 3, 1, 2))
torchvision.utils.save_image(img, output_file)
def visual_dbg(self, step, path):
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
for i, c in enumerate(convs):
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
def encode(self, input):
enc_b = checkpoint(self.enc_b, input)
enc_t = checkpoint(self.enc_t, enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = checkpoint(self.dec_t, quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = checkpoint(self.dec, quant)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
def convert_weights(weights_file):
sd = torch.load(weights_file)
import models.vqvae.vqvae_no_conv_transpose as stdvq
std_model = stdvq.VQVAE()
std_model.load_state_dict(sd)
nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b'])
torch.save(nsd, "converted.pth")
@register_model
def register_vqvae_norm_switched_conv_lambda(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE(**kw)
if __name__ == '__main__':
#v = VQVAE()
#print(v(torch.randn(1,3,128,128))[0].shape)
convert_weights("../../../experiments/4000_generator.pth")


@@ -55,7 +55,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-17.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-20.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt


@@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_audio_clips_with_quantizer_compression.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()