diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index 6ddb4e8b..157b4007 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -208,21 +208,28 @@ class ImageFolderDataset:
             # This assumes the output format generated by the tiled image generation scripts included with DLAS. Specifically,
             # all image read by this dataset are assumed to be in subfolders with other tiles from the same source image. When
             # this option is set, another random image from the same folder is selected and returned as the alt image.
-            sel_path = self.image_paths[item]
-            other_images = random.shuffle(os.listdir(sel_path))
+            sel_path = os.path.dirname(self.image_paths[item])
+            other_images = os.listdir(sel_path)
             # Assume that the directory contains at least <this image>, <a 'ref.' image>, <'centers.pt'>.
-            if len(other_images) <= 3:
-                alt_hq = hq # This is a fallback in case an alt image can't be found.
-            else:
-                for oi in other_images:
-                    if oi == sel_path or 'ref.' in oi or 'centers.pt' in oi:
-                        continue
-                    alt_hq = util.read_img(None, oi, rgb=True)
-                    alt_hs = self.resize_hq([alt_hq])
-                    alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
-                    out_dict['has_alt'] = True
+            try:
+                if len(other_images) <= 3:
+                    alt_hq = hq # This is a fallback in case an alt image can't be found.
+                else:
+                    random.shuffle(other_images)
+                    for oi in other_images:
+                        if oi == os.path.basename(self.image_paths[item]) or 'ref.' in oi or 'centers.pt' in oi:
+                            continue
+                        alt_hq = util.read_img(None, os.path.join(sel_path, oi), rgb=True)
+                        alt_hs = self.resize_hq([alt_hq])
+                        alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
+                        break # Only one alt image is needed; stop after the first valid candidate.
+            except Exception as e:
+                alt_hq = hq
+                print(f"Error with {self.image_paths[item]}: {e}")
+            out_dict['has_alt'] = True
             out_dict['alt_hq'] = alt_hq
+
         if not self.skip_lq:
             lqs, ent = self.synthesize_lq(for_lq)
             ls = lqs[0]
@@ -263,13 +270,14 @@ if __name__ == '__main__':
         'scale': 2,
         'corrupt_before_downsize': True,
         'fetch_alt_image': False,
+        'fetch_alt_tiled_image': True,
         'disable_flip': True,
         'fixed_corruptions': [ 'jpeg-medium' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }
-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=4, batch_size=64)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64)
     import os
     output_path = 'F:\\tmp'
     os.makedirs(output_path, exist_ok=True)
diff --git a/codes/models/diffusion/unet_latent_guide.py b/codes/models/diffusion/unet_latent_guide.py
index 79aae0f5..4c1881ac 100644
--- a/codes/models/diffusion/unet_latent_guide.py
+++ b/codes/models/diffusion/unet_latent_guide.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 import math
+from typing import Union, Type, Callable, Optional, List
 import numpy as np
 import torch
@@ -8,6 +9,10 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision # For debugging, not actually used.
+from kornia.augmentation import ColorJitter +from torch import Tensor +from torchvision.models import resnet50 +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32 from models.diffusion.nn import ( @@ -677,231 +682,159 @@ class SuperResModel(UNetModel): corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) upsampled = torch.cat([upsampled, corruption_factor], dim=1) x = th.cat([x, upsampled], dim=1) - res = super().forward(x, timesteps, latent, **kwargs) + res = super().forward(x, latent, timesteps, **kwargs) return res -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - - For usage, see UNet. - """ +class ResNetEncoder(nn.Module): def __init__( self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - ): + block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck, + layers: List[int] = [3, 4, 6, 3], + depth: int = 4, + output_dim: int = 512, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNetEncoder, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.depth = depth + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + f=128 + if self.depth > 2: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + f=256 + if self.depth > 3: + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + f=512 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(f * block.expansion, output_dim) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + if self.depth > 2: + x = self.layer3(x) + if self.depth > 3: + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +class UnetWithBuiltInLatentEncoder(nn.Module): + def __init__(self, **kwargs): + depth_map = { + 256: 4, + 128: 3, + 64: 2 + } super().__init__() + self.encoder = ResNetEncoder(depth=depth_map[kwargs['image_size']]) + self.lq_jitter = ColorJitter(.05, .05, .05, .05) + self.unet = SuperResModel(**kwargs) - if num_heads_upsample == -1: - num_heads_upsample = num_heads + def forward(self, x, timesteps, alt_hq, low_res=None, **kwargs): + latent = self.encoder(alt_hq) + low_res = self.lq_jitter((low_res+1)/2)*2-1 + return self.unet(x, timesteps, latent, low_res, **kwargs) - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - 
use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. 
- """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) @register_model -def register_unet_diffusion(opt_net, opt): - return SuperResModel(**opt_net['args']) +def register_unet_diffusion_latent_guide(opt_net, opt): + return UnetWithBuiltInLatentEncoder(**opt_net['args']) + if __name__ == '__main__': attention_ds = [] for res in "16,8".split(","): attention_ds.append(128 // int(res)) - srm = SuperResModel(image_size=128, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4, + srm = UnetWithBuiltInLatentEncoder(image_size=64, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4, num_heads_upsample=-1, use_scale_shift_norm=True) - x = torch.randn(1,3,128,128) + x = torch.randn(1,3,64,64) + alt_x = torch.randn(1,3,64,64) l = torch.randn(1,3,32,32) ts = torch.LongTensor([555]) - y = srm(x, ts, low_res=l) + y = srm(x, ts, alt_x, low_res=l) print(y.shape, y.mean(), y.std(), y.min(), y.max()) diff --git a/codes/train.py b/codes/train.py index daf49b0d..df293997 100644 --- a/codes/train.py +++ b/codes/train.py @@ -299,7 +299,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_latent_unet_diffusion_sm.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()