From a6f0f854b987ac724e258af8b042ea4459a571bc Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 17 Oct 2021 22:51:17 -0600 Subject: [PATCH] Fix codes when inferring from dvae --- codes/models/arch_util.py | 535 +++++++++++++++++++--- codes/models/discriminator_vgg_arch.py | 1 - codes/models/gpt_voice/lucidrains_dvae.py | 9 +- 3 files changed, 474 insertions(+), 71 deletions(-) diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 997b925f..2b171fe8 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -1,3 +1,6 @@ +import math +from abc import abstractmethod + import torch import torch.nn as nn import torch.nn.init as init @@ -72,95 +75,491 @@ def default_init_weights(module, scale=1): elif isinstance(m, nn.Linear): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale +""" +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn -class ResidualBlock(nn.Module): - '''Residual block with BN - ---Conv-BN-ReLU-Conv-+- - |________________| - ''' +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) - def __init__(self, nf=64): - super(ResidualBlock, self).__init__() - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.BN1 = nn.BatchNorm2d(nf) - self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.BN2 = nn.BatchNorm2d(nf) - # initialization - initialize_weights([self.conv1, self.conv2], 0.1) +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) def forward(self, x): - identity = x - out = self.lrelu(self.BN1(self.conv1(x))) - out = self.BN2(self.conv2(out)) - return identity + out + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] -class ResidualBlockSpectralNorm(nn.Module): - '''Residual block with Spectral Normalization. - ---SpecConv-ReLU-SpecConv-+- - |________________| - ''' - def __init__(self, nf, total_residual_blocks): - super(ResidualBlockSpectralNorm, self).__init__() - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)) - self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)) +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ - initialize_weights([self.conv1, self.conv2], 1) + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if factor is None: + if dims == 1: + self.factor = 4 + else: + self.factor = 2 + else: + self.factor = factor + if use_conv: + ksize = 3 + pad = 1 + if dims == 1: + ksize = 5 + pad = 2 + self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad) def forward(self, x): - identity = x - out = self.lrelu(self.conv1(x)) - out = self.conv2(out) - return identity + out + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + x = F.interpolate(x, scale_factor=self.factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x -class ResidualBlock_noBN(nn.Module): - '''Residual block w/o BN - ---Conv-ReLU-Conv-+- - |________________| - ''' - def __init__(self, nf=64): - super(ResidualBlock_noBN, self).__init__() - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. - # initialization - initialize_weights([self.conv1, self.conv2], 0.1) + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + ksize = 3 + pad = 1 + if dims == 1: + stride = 4 + ksize = 5 + pad = 2 + elif dims == 2: + stride = 2 + else: + stride = (1,2,2) + if factor is not None: + stride = factor + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): - identity = x - out = self.lrelu(self.conv1(x)) - out = self.conv2(out) - return identity + out + assert x.shape[1] == self.channels + return self.op(x) -class ResidualBlockGN(nn.Module): - '''Residual block with GroupNorm - ---Conv-GN-ReLU-Conv-+- - |________________| - ''' +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. - def __init__(self, nf=64): - super(ResidualBlockGN, self).__init__() - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.BN1 = nn.GroupNorm(8, nf) - self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.BN2 = nn.GroupNorm(8, nf) + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ - # initialization - initialize_weights([self.conv1, self.conv2], 0.1) + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + kernel_size=3, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + padding = 1 if kernel_size == 3 else 2 - def forward(self, x): - identity = x - out = self.lrelu(self.BN1(self.conv1(x))) - out = self.BN2(self.conv2(out)) - return identity + out + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, x, emb + ) + + def _forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_new_attention_order=False,
+        do_checkpoint=True,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.do_checkpoint = do_checkpoint
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x, mask=None):
+        if self.do_checkpoint:
+            return checkpoint(self._forward, x, mask)
+        else:
+            return self._forward(x, mask)
+
+    def _forward(self, x, mask):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv, mask)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv, mask=None):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = torch.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        if mask is not None:
+            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+            weight = weight * mask
+        a = torch.einsum("bts,bcs->bct", weight, v)
+
+        return a.reshape(bs, -1, length)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv, mask=None):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = torch.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        if mask is not None:
+            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/discriminator_vgg_arch.py index a9c1572e..234272de 100644 --- a/codes/models/discriminator_vgg_arch.py +++ b/codes/models/discriminator_vgg_arch.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN import torch.nn.functional as F from trainer.networks import register_model diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 48b19b2e..d9c9eac7 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -179,6 +179,7 @@ class DiscreteVAE(nn.Module): self, img_seq ): + self.log_codes(img_seq) image_embeds = self.codebook.embed_code(img_seq) b, n, d = image_embeds.shape @@ -227,6 +228,12 @@ class DiscreteVAE(nn.Module): # discretization loss disc_loss = self.discrete_loss(soft_codes) + # This is so we can debug the distribution of codes being learned. + self.log_codes(codes) + + return recon_loss, commitment_loss, disc_loss, out + + def log_codes(self, codes): # This is so we can debug the distribution of codes being learned. if self.record_codes and self.internal_step % 50 == 0: codes = codes.flatten() @@ -238,8 +245,6 @@ class DiscreteVAE(nn.Module): self.code_ind = 0 self.internal_step += 1 - return recon_loss, commitment_loss, disc_loss, out - @register_model def register_lucidrains_dvae(opt_net, opt):
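
The notes below are illustrative sketches and usage examples for the blocks this patch adds; they are not part of the diff, and any helper names they introduce (pick_groups, checkpoint_compat, and so on) are hypothetical.

normalization() picks a GroupNorm group count from the channel count: 32 groups by default, 16 for 64 channels or fewer, 8 for 16 or fewer, then halved until it divides the channel count, while GroupNorm32 runs the normalization in float32 and casts back to the input dtype. A minimal sketch of that rule, with pick_groups as a hypothetical mirror of the logic:

    import torch
    import torch.nn as nn

    def pick_groups(channels):
        # Mirrors the group-count rule in normalization() above.
        groups = 32
        if channels <= 16:
            groups = 8
        elif channels <= 64:
            groups = 16
        while channels % groups != 0:
            groups = int(groups / 2)
        assert groups > 2
        return groups

    for ch in (16, 48, 80, 192):
        gn = nn.GroupNorm(pick_groups(ch), ch)
        y = gn(torch.randn(2, ch, 7))   # GroupNorm accepts any (N, C, ...) tensor
        print(ch, pick_groups(ch), tuple(y.shape))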
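
checkpoint(func, inputs, params, flag) expects inputs and params as sequences and, when flag is truthy, defers to a CheckpointFunction autograd Function that this hunk references but does not add (it may exist elsewhere in arch_util.py, or come from the guided-diffusion nn.py this code is adapted from). Note that the ResBlock and AttentionBlock forwards above call it as checkpoint(self._forward, x, emb) and checkpoint(self._forward, x, mask), which does not line up with the four-argument signature defined here. A hedged sketch of a call that does match, using a hypothetical checkpoint_compat stand-in that falls back to torch.utils.checkpoint so it runs on its own:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    def checkpoint_compat(func, inputs, params, flag):
        # Same calling convention as the wrapper above; params is accepted for signature
        # compatibility, but torch.utils.checkpoint tracks parameters on its own.
        if flag:
            return torch_checkpoint(func, *inputs)
        return func(*inputs)

    layer = nn.Linear(8, 8)
    x = torch.randn(4, 8, requires_grad=True)
    out = checkpoint_compat(lambda t: layer(t).relu(), (x,), layer.parameters(), True)
    out.sum().backward()
    print(x.grad.shape)   # gradients flow normally; activations are recomputed in backward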
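
Upsample and Downsample gain a dims argument with audio-friendly defaults: for dims=1 the scale factor defaults to 4 and the convolution uses kernel 5 with padding 2, while dims=2 keeps the usual 2x behaviour with a 3x3 kernel. A usage sketch, assuming codes/ is on PYTHONPATH so models.arch_util is importable:

    import torch
    from models.arch_util import Upsample, Downsample

    audio = torch.randn(2, 64, 100)   # (batch, channels, time)
    print(Upsample(64, use_conv=True, dims=1)(audio).shape)     # -> (2, 64, 400)
    print(Downsample(64, use_conv=True, dims=1)(audio).shape)   # -> (2, 64, 25)

    image = torch.randn(2, 64, 32, 32)   # (batch, channels, H, W)
    print(Upsample(64, use_conv=True, dims=2)(image).shape)     # -> (2, 64, 64, 64)
    print(Downsample(64, use_conv=True, dims=2)(image).shape)   # -> (2, 64, 16, 16)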
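
ResBlock here drops the timestep-embedding projection of the guided-diffusion original: _forward() takes only x, while forward(x, emb) still hands (x, emb) to checkpoint(), whose signature in this file is (func, inputs, params, flag). The sketch below therefore exercises the block through _forward() directly, which is what forward() amounts to when checkpointing is disabled; same PYTHONPATH assumption as above:

    import torch
    from models.arch_util import ResBlock

    block = ResBlock(channels=64, dropout=0.1, out_channels=128, dims=1)
    x = torch.randn(2, 64, 100)
    print(block._forward(x).shape)   # -> (2, 128, 100), via the 1x1 skip projection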
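
AttentionBlock works on (batch, channels, tokens) tensors; passing do_checkpoint=False sidesteps the checkpoint() call noted above, and the optional mask is a (batch, tokens) tensor that the default (legacy) attention path multiplies into the weights after the softmax. A usage sketch under the same PYTHONPATH assumption:

    import torch
    from models.arch_util import AttentionBlock

    attn = AttentionBlock(channels=64, num_heads=4, do_checkpoint=False)
    x = torch.randn(2, 64, 50)   # (batch, channels, tokens)
    mask = torch.ones(2, 50)
    mask[:, 25:] = 0             # attend only to the first half of the sequence
    print(attn(x, mask).shape)   # -> (2, 64, 50)

The example uses the same mask for every sample on purpose: mask.repeat(self.n_heads, 1) tiles the batch dimension rather than interleaving it, so per-sample masks may not line up with the bs * n_heads head ordering produced by the qkv reshape.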
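
On the lucidrains_dvae side, the code-recording logic moves out of forward() into a log_codes() helper, and decode() now routes the incoming code sequence through it as well, so codes produced at inference time land in the same debug buffer. The sketch below shows the kind of code-distribution check this enables; num_tokens and the random codes are stand-ins for the real codebook size and the recorded buffer:

    import torch

    num_tokens = 512                                            # hypothetical codebook size
    codes = torch.randint(0, num_tokens, (16, 64)).flatten()    # stand-in for recorded codes
    hist = torch.bincount(codes, minlength=num_tokens)
    print("codebook entries in use:", (hist > 0).sum().item(), "/", num_tokens)
    print("most frequent code:", hist.argmax().item(), "count:", hist.max().item())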