46e9f62be0
This is a diffusion network that uses both an LQ image and a reference HQ sample image (compressed into a latent vector) to perform upsampling. The hope is that we can steer the upsampling network with sample images.
841 lines
30 KiB
Python
841 lines
30 KiB
Python
from abc import abstractmethod
|
|
|
|
import math
|
|
from typing import Union, Type, Callable, Optional, List
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch as th
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision # For debugging, not actually used.
|
|
from kornia.augmentation import ColorJitter
|
|
from torch import Tensor
|
|
from torchvision.models import resnet50
|
|
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
|
|
|
|
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
|
|
from models.diffusion.nn import (
|
|
conv_nd,
|
|
linear,
|
|
avg_pool_nd,
|
|
zero_module,
|
|
normalization,
|
|
timestep_embedding,
|
|
)
|
|
from trainer.networks import register_model
|
|
from utils.util import checkpoint
|
|
|
|
|
|
class AttentionPool2d(nn.Module):
    """
    Attention-based global pooling of a spatial feature map.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    The spatial positions are flattened, their mean is prepended as an extra
    token, a learned positional embedding is added, multi-head QKV attention is
    applied, and the projected output at the prepended (mean) token is returned.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One token per spatial position plus the prepended mean token.
        token_count = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, token_count) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels, *_ = x.shape
        tokens = x.reshape(batch, channels, -1)  # N x C x (HW)
        mean_token = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([mean_token, tokens], dim=-1)  # N x C x (HW + 1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)
        attended = self.attention(self.qkv_proj(tokens))
        projected = self.c_proj(attended)
        return projected[:, :, 0]
|
|
|
|
|
|
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Subclasses must override ``forward(x, emb)``. ``nn.Module`` does not use
    ``ABCMeta``, so ``@abstractmethod`` alone does not stop a subclass without
    ``forward`` from being instantiated. The base implementation therefore
    raises ``NotImplementedError`` rather than silently returning ``None``.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.

        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward(x, emb)"
        )
|
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that also routes the timestep embedding.

    Children that are ``TimestepBlock`` instances are called as
    ``child(h, emb)``; every other child is called as ``child(h)``. Children run
    in registration order, each receiving the previous child's output.
    """

    def forward(self, x, emb):
        h = x
        for child in self:
            h = child(h, emb) if isinstance(child, TimestepBlock) else child(h)
        return h
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a 3x3 convolution is applied after
                     nearest-neighbour upsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channel count (defaults to ``channels``). It may
                         only differ from ``channels`` when ``use_conv`` is True.
    :raises ValueError: if ``out_channels`` differs from ``channels`` while
                        ``use_conv`` is False.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
        elif self.out_channels != self.channels:
            # Nearest-neighbour interpolation cannot change the channel count, so
            # a different out_channels would be silently ignored (Downsample
            # rejects the same configuration).
            raise ValueError(
                f"Upsample without conv cannot map {self.channels} channels "
                f"to {self.out_channels}"
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep the outer (depth/time) axis, double the inner two.
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
|
|
|
|
|
|
class Downsample(nn.Module):
    """
    A 2x downsampling layer, either a strided convolution or average pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: if True, downsample with a strided 3x3 convolution;
                     otherwise use average pooling (which requires
                     ``out_channels == channels``).
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channel count (defaults to ``channels``).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals the outer axis is left at stride 1.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
|
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels differs from channels, use a
        spatial 3x3 convolution instead of a smaller 1x1 convolution to change
        the channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is projected to
        2 * out_channels and split into (scale, shift), applied as
        ``norm(h) * (1 + scale) + shift`` (FiLM-like). Otherwise the projected
        embedding is added to h before the output layers.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling (ignored if up is
        also True).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> SiLU -> conv. _forward relies on the conv being the LAST
        # element so the resampling can be inserted just before it.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Separate resamplers for the residual path (h) and the skip path (x).
        # Both still carry `channels` channels at the point they are applied.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding to either a per-channel bias or a
        # (scale, shift) pair, depending on use_scale_shift_norm.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # The final conv is zero-initialised (via zero_module), so the residual
        # branch presumably starts as a no-op at init; see zero_module to confirm.
        # _forward relies on element 0 being the normalization layer.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        The work is delegated to ``_forward`` through ``utils.util.checkpoint``
        (presumably gradient checkpointing; confirm in utils.util).

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, x, emb
        )

    def _forward(self, x, emb):
        """Un-checkpointed body of ``forward``; same arguments and result."""
        if self.updown:
            # Resample between the activation and the conv of in_layers, and
            # resample the skip input identically so the shapes match.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton dims so the embedding broadcasts over space.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """
    Self-attention across all spatial positions of an N-d feature map.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: input (and output) channel count.
    :param num_heads: number of heads; used only when num_head_channels == -1.
    :param num_head_channels: if not -1, a fixed width per head; the head count
        becomes channels // num_head_channels.
    :param use_new_attention_order: if True use QKVAttention (split q/k/v before
        heads), otherwise QKVAttentionLegacy (split heads before q/k/v).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attention_cls(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, x)

    def _forward(self, x):
        batch, chans, *spatial = x.shape
        flat = x.reshape(batch, chans, -1)
        attended = self.attention(self.qkv(self.norm(flat)))
        residual = flat + self.proj_out(attended)
        return residual.reshape(batch, chans, *spatial)
|
|
|
|
|
|
def count_flops_attn(model, _x, y):
    """
    `thop` counter hook for attention layers: adds the attention matmul cost
    to ``model.total_ops``.

    Meant to be used like:

        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )

    Only ``y[0]`` is inspected: its first two dims are read as (batch,
    channels) and the product of the remaining dims as the token count.
    """
    batch, channels, *spatial = y[0].shape
    tokens = math.prod(spatial)
    # Two matmuls of identical cost: one builds the attention weights, the
    # other mixes the value vectors with them.
    ops = 2 * batch * tokens * tokens * channels
    model.total_ops += th.DoubleTensor([ops])
|
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.

    The channel axis is split into heads first; each head's slice is then
    split into its q, k and v parts.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_ch = width // (3 * self.n_heads)
        per_head = qkv.reshape(batch * self.n_heads, 3 * head_ch, length)
        q, k, v = per_head.split(head_ch, dim=1)
        # Scale q and k by ch**-0.25 each instead of the logits by ch**-0.5;
        # this is more stable in float16.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", attn, v)
        return mixed.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
|
|
|
|
|
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.

    The channel axis is split into q, k and v first; each of those is then
    split into heads.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k by ch**-0.25 each instead of the logits by ch**-0.5;
        # this is more stable in float16.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        head_shape = (batch * self.n_heads, head_ch, length)
        logits = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(*head_shape),
            (k * scale).view(*head_shape),
        )
        attn = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", attn, v.reshape(*head_shape))
        return mixed.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
|
|
|
|
|
class UNetModel(nn.Module):
    """
    The full UNet model with attention, timestep embedding, and a global latent
    vector injected at the bottleneck.

    :param image_size: stored as ``self.image_size``; not otherwise used here.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. The latent injection
        in ``forward`` assumes 2D feature maps.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_fp16: sets ``self.dtype`` to float16. Weights are converted only
        when ``convert_to_fp16()`` is called.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---- Encoder (input) path -------------------------------------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel count of every input-block output, consumed (in reverse) by
        # the decoder's skip connections.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample factor relative to the input
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- Bottleneck ------------------------------------------------------
        # Fuses the bottleneck features with the broadcast latent (which must
        # therefore have `ch` channels) back down to `ch` channels.
        self.latent_join_reduce = ResBlock(
            ch * 2,
            time_embed_dim,
            dropout,
            out_channels=ch,
            dims=dims,
            use_scale_shift_norm=use_scale_shift_norm,
        )
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- Decoder (output) path ------------------------------------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # The last decoder block emits `ch` (= model_channels * channel_mult[0])
        # channels. The conv must read that count, not model_channels, or it
        # breaks whenever channel_mult[0] != 1.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.

        The torso is every block between the time embedding and the output head:
        input blocks, the latent-join block, the middle block and output blocks.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.latent_join_reduce.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.

        Mirrors ``convert_to_fp16``, including the latent-join block.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.latent_join_reduce.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, latent, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x H x W] Tensor of inputs.
        :param latent: an [N x L] Tensor. It is broadcast over the bottleneck's
            spatial positions and concatenated with the bottleneck features,
            so L must equal the bottleneck width
            (model_channels * channel_mult[-1]).
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)

        b, c = latent.shape
        # Cast to the torso dtype so fp16 mode does not mix precisions, and
        # broadcast (without copying) over every bottleneck position.
        latent_map = (
            latent.type(h.dtype)
            .reshape(b, c, 1, 1)
            .expand(-1, -1, h.shape[-2], h.shape[-1])
        )
        h = th.cat([h, latent_map], dim=1)
        h = self.latent_join_reduce(h, emb)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
|
|
|
|
|
|
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    The low-res image is bilinearly resized to the spatial size of ``x``,
    followed by ``num_corruptions`` constant-valued channels (zeros when no
    ``corruption_factor`` is given), and concatenated to ``x`` along the
    channel axis.

    :param image_size: forwarded to UNetModel.
    :param in_channels: channels of the noisy image (and of ``low_res``).
    :param num_corruptions: number of extra per-sample corruption channels.
    """

    def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
        # Model input = noisy image + upsampled LQ image + corruption channels.
        super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
        # Must match the channels reserved above; used to build zero corruption
        # maps when forward() gets no corruption_factor.
        self.num_corruptions = num_corruptions

    def forward(self, x, timesteps, latent, low_res=None, corruption_factor=None, **kwargs):
        """
        :param x: noisy [N x C x H x W] input.
        :param timesteps: a 1-D batch of timesteps.
        :param latent: [N x L] guidance latent, forwarded to UNetModel.forward.
        :param low_res: low-resolution conditioning image (required).
        :param corruption_factor: optional per-sample values, reshaped to
            [N x num_corruptions] and broadcast over H x W.
        :raises ValueError: if ``low_res`` is None.
        """
        if low_res is None:
            raise ValueError("SuperResModel.forward requires a low_res image")
        b, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        if corruption_factor is not None:
            corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(
                1, 1, new_height, new_width
            )
        else:
            corruption_factor = torch.zeros(
                (b, self.num_corruptions, new_height, new_width),
                dtype=upsampled.dtype,
                device=x.device,
            )
        upsampled = torch.cat([upsampled, corruption_factor], dim=1)
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, latent, timesteps, **kwargs)
|
|
|
|
|
|
class ResNetEncoder(nn.Module):
    """
    ResNet trunk (adapted from torchvision's ResNet) that encodes a 3-channel
    image into a single feature vector.

    layer1 and layer2 are always built; layer3 is added when ``depth > 2`` and
    layer4 when ``depth > 3``. The last stage's output is global-average-pooled
    and projected to ``output_dim`` with a linear layer.

    :param block: residual block type (BasicBlock or Bottleneck).
    :param layers: number of blocks per stage; defaults to ResNet-50's
        [3, 4, 6, 3]. Entries for stages not built (see ``depth``) are unused.
    :param depth: how many stages to build (see above).
    :param output_dim: size of the returned feature vector.
    :param zero_init_residual: zero-init the last norm in each residual branch.
    :param groups: grouped-convolution count passed to each block.
    :param width_per_group: base width passed to each block.
    :param replace_stride_with_dilation: three flags, one per stride-2 stage
        (layer2..layer4), replacing the stride with dilation.
    :param norm_layer: normalization constructor (default nn.BatchNorm2d).
    :raises ValueError: if replace_stride_with_dilation does not have 3 entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
        layers: Optional[List[int]] = None,
        depth: int = 4,
        output_dim: int = 512,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNetEncoder, self).__init__()
        if layers is None:
            # ResNet-50 stage sizes; a None sentinel avoids a shared mutable
            # default list.
            layers = [3, 4, 6, 3]
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.depth = depth
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        # Planes of the deepest stage actually built; sizes the final linear.
        final_planes = 128
        if self.depth > 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilate=replace_stride_with_dilation[1])
            final_planes = 256
        if self.depth > 3:
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
            final_planes = 512
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(final_planes * block.expansion, output_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one ResNet stage of ``blocks`` blocks; updates self.inplanes/self.dilation."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut when the first block changes shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        if self.depth > 2:
            x = self.layer3(x)
        if self.depth > 3:
            x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Encode images to [N x output_dim] feature vectors."""
        return self._forward_impl(x)
|
|
|
|
|
|
class UnetWithBuiltInLatentEncoder(nn.Module):
    """
    Super-resolution UNet steered by a reference HQ image.

    ``alt_hq`` is encoded by a ResNet into one latent vector per sample. The
    UNet concatenates that vector with its bottleneck features. The LQ
    conditioning image is colour-jittered before it reaches the UNet.

    All keyword arguments are forwarded to SuperResModel. ``image_size``
    selects the encoder depth, and ``model_channels`` / ``channel_mult`` fix
    the latent size. It must equal the UNet bottleneck width
    model_channels * channel_mult[-1].

    :raises ValueError: if ``image_size`` is not 64, 128 or 256.
    """

    # Encoder stages to build for each supported image size.
    _DEPTH_FOR_IMAGE_SIZE = {256: 4, 128: 3, 64: 2}
    # Must match UNetModel's default channel_mult.
    _DEFAULT_CHANNEL_MULT = (1, 2, 4, 8)

    def __init__(self, **kwargs):
        super().__init__()
        image_size = kwargs['image_size']
        if image_size not in self._DEPTH_FOR_IMAGE_SIZE:
            raise ValueError(
                f"Unsupported image_size {image_size}; expected one of "
                f"{sorted(self._DEPTH_FOR_IMAGE_SIZE)}"
            )
        # latent_join_reduce expects bottleneck features and latent of equal width.
        channel_mult = kwargs.get('channel_mult', self._DEFAULT_CHANNEL_MULT)
        latent_dim = kwargs['model_channels'] * channel_mult[-1]
        self.encoder = ResNetEncoder(
            depth=self._DEPTH_FOR_IMAGE_SIZE[image_size], output_dim=latent_dim
        )
        self.lq_jitter = ColorJitter(.05, .05, .05, .05)
        self.unet = SuperResModel(**kwargs)

    def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
        """
        :param x: noisy input for the UNet.
        :param timesteps: a 1-D batch of timesteps.
        :param alt_hq: reference HQ image batch fed to the encoder.
        :param low_res: LQ conditioning image in [-1, 1] (required).
        :raises ValueError: if ``low_res`` is None.
        """
        if low_res is None:
            raise ValueError("UnetWithBuiltInLatentEncoder.forward requires low_res")
        latent = self.encoder(alt_hq)
        # Jitter in [0, 1] space (presumably ColorJitter's expected range),
        # then map back to [-1, 1].
        low_res = self.lq_jitter((low_res + 1) / 2) * 2 - 1
        return self.unet(x, timesteps, latent, low_res, **kwargs)
|
|
|
|
|
|
@register_model
def register_unet_diffusion_latent_guide(opt_net, opt):
    """
    Registry factory: build an UnetWithBuiltInLatentEncoder from the network
    options. Only ``opt_net['args']`` is used; ``opt`` is ignored.
    """
    model_kwargs = opt_net['args']
    return UnetWithBuiltInLatentEncoder(**model_kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a small model and run one forward pass on random data.
    image_size = 64
    # Attention resolutions are given in pixels; convert them to downsample
    # factors relative to the model's own image size (not a hard-coded 128).
    attention_ds = [image_size // int(res) for res in "16,8".split(",")]
    srm = UnetWithBuiltInLatentEncoder(image_size=image_size, in_channels=3, model_channels=64, out_channels=3,
                                       num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4,
                                       num_heads_upsample=-1, use_scale_shift_norm=True)
    x = torch.randn(1, 3, image_size, image_size)
    alt_x = torch.randn(1, 3, image_size, image_size)
    l = torch.randn(1, 3, image_size // 2, image_size // 2)
    ts = torch.LongTensor([555])
    y = srm(x, ts, alt_x, low_res=l)
    print(y.shape, y.mean(), y.std(), y.min(), y.max())
|