Add unet with latent guide

This is a diffusion network that uses both an LQ image
and a reference HQ sample image, compressed into
a latent vector, to perform upsampling.

The hope is that we can steer the upsampling network
with sample images.
James Betker 2021-06-26 11:02:58 -06:00
parent 0ded106562
commit 46e9f62be0
3 changed files with 162 additions and 222 deletions
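
In rough terms, the new model compresses the reference HQ tile into a latent vector with a small ResNet encoder and hands that vector, together with the LQ conditioning image, to the super-resolution UNet at every denoising step. The sketch below is illustrative only: the encoder is a toy stand-in (the real class is ResNetEncoder further down, which defaults to a 512-dim latent) and the shapes are borrowed from the smoke test at the bottom of the UNet diff.

    import torch
    import torch.nn as nn

    # Toy stand-in for ResNetEncoder, only to show the shapes involved.
    encoder = nn.Sequential(
        nn.Conv2d(3, 8, 3, stride=2, padding=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 512),                   # ResNetEncoder's default output_dim is 512
    )

    x = torch.randn(1, 3, 64, 64)            # noisy sample at HQ resolution being denoised
    alt_hq = torch.randn(1, 3, 64, 64)       # reference HQ tile from the same source image
    low_res = torch.randn(1, 3, 32, 32)      # LQ image being upsampled
    t = torch.LongTensor([555])              # diffusion timestep

    latent = encoder(alt_hq)                 # reference compressed into a guidance vector
    print(latent.shape)                      # torch.Size([1, 512])
    # The real wrapper then calls, roughly: unet(x, t, latent, low_res=low_res)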


@@ -208,21 +208,27 @@ class ImageFolderDataset:
             # This assumes the output format generated by the tiled image generation scripts included with DLAS. Specifically,
             # all image read by this dataset are assumed to be in subfolders with other tiles from the same source image. When
             # this option is set, another random image from the same folder is selected and returned as the alt image.
-            sel_path = self.image_paths[item]
-            other_images = random.shuffle(os.listdir(sel_path))
+            sel_path = os.path.dirname(self.image_paths[item])
+            other_images = os.listdir(sel_path)
             # Assume that the directory contains at least <image>, <ref.jpg>, <centers.pt>
-            if len(other_images) <= 3:
-                alt_hq = hq # This is a fallback in case an alt image can't be found.
-            else:
-                for oi in other_images:
-                    if oi == sel_path or 'ref.' in oi or 'centers.pt' in oi:
-                        continue
-                    alt_hq = util.read_img(None, oi, rgb=True)
-                    alt_hs = self.resize_hq([alt_hq])
-                    alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
-            out_dict['has_alt'] = True
+            try:
+                if len(other_images) <= 3:
+                    alt_hq = hq # This is a fallback in case an alt image can't be found.
+                else:
+                    random.shuffle(other_images)
+                    for oi in other_images:
+                        if oi == os.path.basename(self.image_paths[item]) or 'ref.' in oi or 'centers.pt' in oi:
+                            continue
+                        alt_hq = util.read_img(None, os.path.join(sel_path, oi), rgb=True)
+                        alt_hs = self.resize_hq([alt_hq])
+                        alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
+            except:
+                alt_hq = hq
+                print(f"Error with {self.image_paths[item]}")
+            out_dict['has_alt'] = True
             out_dict['alt_hq'] = alt_hq
         if not self.skip_lq:
             lqs, ent = self.synthesize_lq(for_lq)
             ls = lqs[0]
@@ -263,13 +269,14 @@ if __name__ == '__main__':
         'scale': 2,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
+        'fetch_alt_tiled_image': True,
         'disable_flip': True,
         'fixed_corruptions': [ 'jpeg-medium' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }
-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
     import os
     output_path = 'F:\\tmp'
     os.makedirs(output_path, exist_ok=True)
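
For context on the alt-tile hunk above: the dataset assumes every tile sits in a folder alongside the other tiles cut from the same source image, plus a ref.* image and a centers.pt file, and it returns one random sibling tile as alt_hq (falling back to the tile's own HQ image when no sibling exists). A standalone paraphrase of that selection rule follows, with an illustrative path; it is a reading aid, not DLAS API.

    import os
    import random

    def pick_alt_tile(tile_path):
        """Return the path of a sibling tile from the same source-image folder, or None."""
        folder = os.path.dirname(tile_path)
        candidates = os.listdir(folder)
        # The folder is assumed to hold at least <tiles...>, ref.jpg and centers.pt.
        if len(candidates) <= 3:
            return None  # no usable sibling; the dataset falls back to the original HQ tile
        random.shuffle(candidates)
        for name in candidates:
            if name == os.path.basename(tile_path) or 'ref.' in name or 'centers.pt' in name:
                continue
            return os.path.join(folder, name)
        return None

    # Example (illustrative path only):
    # pick_alt_tile('/data/tiles/img_0001/tile_03.jpg')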


@@ -1,6 +1,7 @@
 from abc import abstractmethod
 import math
+from typing import Union, Type, Callable, Optional, List

 import numpy as np
 import torch
@@ -8,6 +9,10 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision  # For debugging, not actually used.
+from kornia.augmentation import ColorJitter
+from torch import Tensor
+from torchvision.models import resnet50
+from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1

 from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
 from models.diffusion.nn import (
@@ -677,231 +682,159 @@ class SuperResModel(UNetModel):
             corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
             upsampled = torch.cat([upsampled, corruption_factor], dim=1)
         x = th.cat([x, upsampled], dim=1)
-        res = super().forward(x, timesteps, latent, **kwargs)
+        res = super().forward(x, latent, timesteps, **kwargs)
         return res
-class EncoderUNetModel(nn.Module):
-    """
-    The half UNet model with attention and timestep embedding.
-    For usage, see UNet.
-    """
-
-    def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        use_fp16=False,
-        num_heads=1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        pool="adaptive",
-    ):
-        super().__init__()
-
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
-                )
-            ]
-        )
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            num_heads=num_heads,
-                            num_head_channels=num_head_channels,
-                            use_new_attention_order=use_new_attention_order,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                num_heads=num_heads,
-                num_head_channels=num_head_channels,
-                use_new_attention_order=use_new_attention_order,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self._feature_size += ch
-        self.pool = pool
-        if pool == "adaptive":
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                nn.AdaptiveAvgPool2d((1, 1)),
-                zero_module(conv_nd(dims, ch, out_channels, 1)),
-                nn.Flatten(),
-            )
-        elif pool == "attention":
-            assert num_head_channels != -1
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                AttentionPool2d(
-                    (image_size // ds), ch, num_head_channels, out_channels
-                ),
-            )
-        elif pool == "spatial":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                nn.ReLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        elif pool == "spatial_v2":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                normalization(2048),
-                nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        else:
-            raise NotImplementedError(f"Unexpected {pool} pooling")
-
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
-    def forward(self, x, timesteps):
-        """
-        Apply the model to an input batch.
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :return: an [N x K] Tensor of outputs.
-        """
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        results = []
-        h = x.type(self.dtype)
-        for module in self.input_blocks:
-            h = module(h, emb)
-            if self.pool.startswith("spatial"):
-                results.append(h.type(x.dtype).mean(dim=(2, 3)))
-        h = self.middle_block(h, emb)
-        if self.pool.startswith("spatial"):
-            results.append(h.type(x.dtype).mean(dim=(2, 3)))
-            h = th.cat(results, axis=-1)
-            return self.out(h)
-        else:
-            h = h.type(x.dtype)
-            return self.out(h)
+class ResNetEncoder(nn.Module):
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
+        layers: List[int] = [3, 4, 6, 3],
+        depth: int = 4,
+        output_dim: int = 512,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        replace_stride_with_dilation: Optional[List[bool]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(ResNetEncoder, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.depth = depth
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        f=128
+        if self.depth > 2:
+            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                           dilate=replace_stride_with_dilation[1])
+            f=256
+        if self.depth > 3:
+            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                           dilate=replace_stride_with_dilation[2])
+            f=512
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(f * block.expansion, output_dim)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        if self.depth > 2:
+            x = self.layer3(x)
+        if self.depth > 3:
+            x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+class UnetWithBuiltInLatentEncoder(nn.Module):
+    def __init__(self, **kwargs):
+        depth_map = {
+            256: 4,
+            128: 3,
+            64: 2
+        }
+        super().__init__()
+        self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']])
+        self.lq_jitter = ColorJitter(.05, .05, .05, .05)
+        self.unet = SuperResModel(**kwargs)
+
+    def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
+        latent = self.encoder(alt_hq)
+        low_res = self.lq_jitter((low_res+1)/2)*2-1
+        return self.unet(x, timesteps, latent, low_res, **kwargs)

 @register_model
-def register_unet_diffusion(opt_net, opt):
-    return SuperResModel(**opt_net['args'])
+def register_unet_diffusion_latent_guide(opt_net, opt):
+    return UnetWithBuiltInLatentEncoder(**opt_net['args'])


 if __name__ == '__main__':
     attention_ds = []
     for res in "16,8".split(","):
         attention_ds.append(128 // int(res))
-    srm = SuperResModel(image_size=128, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4,
+    srm = UnetWithBuiltInLatentEncoder(image_size=64, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4,
                                        num_heads_upsample=-1, use_scale_shift_norm=True)
-    x = torch.randn(1,3,128,128)
+    x = torch.randn(1,3,64,64)
+    alt_x = torch.randn(1,3,64,64)
     l = torch.randn(1,3,32,32)
     ts = torch.LongTensor([555])
-    y = srm(x, ts, low_res=l)
+    y = srm(x, ts, alt_x, low_res=l)
     print(y.shape, y.mean(), y.std(), y.min(), y.max())
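
One detail worth flagging in UnetWithBuiltInLatentEncoder.forward above: the diffusion tensors live in [-1, 1], while kornia's ColorJitter (used here as lq_jitter) operates on images in [0, 1], which is presumably why the LQ conditioning image takes the (low_res + 1) / 2 ... * 2 - 1 round trip. Below is a tiny self-contained check of that remap; the identity lambda stands in for the jitter, and kornia's range convention is an assumption stated here rather than something the diff asserts.

    import torch

    low_res = torch.rand(1, 3, 32, 32) * 2 - 1   # conditioning image in [-1, 1]

    def jitter_in_unit_range(img, jitter):
        # Mirrors `self.lq_jitter((low_res + 1) / 2) * 2 - 1` from the new forward().
        return jitter((img + 1) / 2) * 2 - 1

    out = jitter_in_unit_range(low_res, lambda t: t)  # identity jitter: exact round trip
    assert torch.allclose(out, low_res)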


@@ -299,7 +299,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()