Add unet with latent guide
This is a diffusion network that uses both a LQ image and a reference sample HQ image that is compressed into a latent vector to perform upsampling The hope is that we can steer the upsampling network with sample images.
This commit is contained in:
parent
0ded106562
commit
46e9f62be0
|
@ -208,21 +208,27 @@ class ImageFolderDataset:
|
|||
# This assumes the output format generated by the tiled image generation scripts included with DLAS. Specifically,
|
||||
# all image read by this dataset are assumed to be in subfolders with other tiles from the same source image. When
|
||||
# this option is set, another random image from the same folder is selected and returned as the alt image.
|
||||
sel_path = self.image_paths[item]
|
||||
other_images = random.shuffle(os.listdir(sel_path))
|
||||
sel_path = os.path.dirname(self.image_paths[item])
|
||||
other_images = os.listdir(sel_path)
|
||||
# Assume that the directory contains at least <image>, <ref.jpg>, <centers.pt>
|
||||
if len(other_images) <= 3:
|
||||
alt_hq = hq # This is a fallback in case an alt image can't be found.
|
||||
else:
|
||||
for oi in other_images:
|
||||
if oi == sel_path or 'ref.' in oi or 'centers.pt' in oi:
|
||||
continue
|
||||
alt_hq = util.read_img(None, oi, rgb=True)
|
||||
alt_hs = self.resize_hq([alt_hq])
|
||||
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
|
||||
out_dict['has_alt'] = True
|
||||
try:
|
||||
if len(other_images) <= 3:
|
||||
alt_hq = hq # This is a fallback in case an alt image can't be found.
|
||||
else:
|
||||
random.shuffle(other_images)
|
||||
for oi in other_images:
|
||||
if oi == os.path.basename(self.image_paths[item]) or 'ref.' in oi or 'centers.pt' in oi:
|
||||
continue
|
||||
alt_hq = util.read_img(None, os.path.join(sel_path, oi), rgb=True)
|
||||
alt_hs = self.resize_hq([alt_hq])
|
||||
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
|
||||
except:
|
||||
alt_hq = hq
|
||||
print(f"Error with {self.image_paths[item]}")
|
||||
out_dict['has_alt'] = True
|
||||
out_dict['alt_hq'] = alt_hq
|
||||
|
||||
|
||||
if not self.skip_lq:
|
||||
lqs, ent = self.synthesize_lq(for_lq)
|
||||
ls = lqs[0]
|
||||
|
@ -263,13 +269,14 @@ if __name__ == '__main__':
|
|||
'scale': 2,
|
||||
'corrupt_before_downsize': True,
|
||||
'fetch_alt_image': False,
|
||||
'fetch_alt_tiled_image': True,
|
||||
'disable_flip': True,
|
||||
'fixed_corruptions': [ 'jpeg-medium' ],
|
||||
'num_corrupts_per_image': 0,
|
||||
'corruption_blur_scale': 0
|
||||
}
|
||||
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
|
||||
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
|
||||
import os
|
||||
output_path = 'F:\\tmp'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from abc import abstractmethod
|
||||
|
||||
import math
|
||||
from typing import Union, Type, Callable, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -8,6 +9,10 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision # For debugging, not actually used.
|
||||
from kornia.augmentation import ColorJitter
|
||||
from torch import Tensor
|
||||
from torchvision.models import resnet50
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
|
||||
|
||||
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
|
||||
from models.diffusion.nn import (
|
||||
|
@ -677,231 +682,159 @@ class SuperResModel(UNetModel):
|
|||
corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
|
||||
upsampled = torch.cat([upsampled, corruption_factor], dim=1)
|
||||
x = th.cat([x, upsampled], dim=1)
|
||||
res = super().forward(x, timesteps, latent, **kwargs)
|
||||
res = super().forward(x, latent, timesteps, **kwargs)
|
||||
return res
|
||||
|
||||
|
||||
class EncoderUNetModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
|
||||
For usage, see UNet.
|
||||
"""
|
||||
class ResNetEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
):
|
||||
block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
|
||||
layers: List[int] = [3, 4, 6, 3],
|
||||
depth: int = 4,
|
||||
output_dim: int = 512,
|
||||
zero_init_residual: bool = False,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(ResNetEncoder, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.depth = depth
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
f=128
|
||||
if self.depth > 2:
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
f=256
|
||||
if self.depth > 3:
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
f=512
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(f * block.expansion, output_dim)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
if self.depth > 2:
|
||||
x = self.layer3(x)
|
||||
if self.depth > 3:
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
class UnetWithBuiltInLatentEncoder(nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
depth_map = {
|
||||
256: 4,
|
||||
128: 3,
|
||||
64: 2
|
||||
}
|
||||
super().__init__()
|
||||
self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']])
|
||||
self.lq_jitter = ColorJitter(.05, .05, .05, .05)
|
||||
self.unet = SuperResModel(**kwargs)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs):
|
||||
latent = self.encoder(alt_hq)
|
||||
low_res = self.lq_jitter((low_res+1)/2)*2-1
|
||||
return self.unet(x, timesteps, latent, low_res, **kwargs)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
AttentionPool2d(
|
||||
(image_size // ds), ch, num_head_channels, out_channels
|
||||
),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
nn.SiLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
@register_model
|
||||
def register_unet_diffusion(opt_net, opt):
|
||||
return SuperResModel(**opt_net['args'])
|
||||
def register_unet_diffusion_latent_guide(opt_net, opt):
|
||||
return UnetWithBuiltInLatentEncoder(**opt_net['args'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
attention_ds = []
|
||||
for res in "16,8".split(","):
|
||||
attention_ds.append(128 // int(res))
|
||||
srm = SuperResModel(image_size=128, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4,
|
||||
srm = UnetWithBuiltInLatentEncoder(image_size=64, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4,
|
||||
num_heads_upsample=-1, use_scale_shift_norm=True)
|
||||
x = torch.randn(1,3,128,128)
|
||||
x = torch.randn(1,3,64,64)
|
||||
alt_x = torch.randn(1,3,64,64)
|
||||
l = torch.randn(1,3,32,32)
|
||||
ts = torch.LongTensor([555])
|
||||
y = srm(x, ts, low_res=l)
|
||||
y = srm(x, ts, alt_x, low_res=l)
|
||||
print(y.shape, y.mean(), y.std(), y.min(), y.max())
|
||||
|
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user