forked from mrq/DL-Art-School
Clean up codebase
Remove stuff that I'm likely not going to use again (or generally failed experiments)
This commit is contained in:
parent
4d1a42e944
commit
55b58fb67f
@@ -1,507 +0,0 @@
import math
import copy
import os
import random
from functools import wraps, partial
from math import floor

import torch
import torchvision
from torch import nn, einsum
import torch.nn.functional as F

from kornia import augmentation as augs
from kornia import filters, color

from einops import rearrange

# helper functions

from trainer.networks import register_model, create_model


def identity(t):
    return t


def default(val, def_val):
    return def_val if val is None else val


def rand_true(prob):
    return random.random() < prob


def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn
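The `singleton` decorator lazily builds an object the first time its getter runs and caches it on the instance under `cache_key`. A minimal sketch of that behaviour (the `Holder` class here is hypothetical, purely for illustration):

class Holder:
    def __init__(self):
        self.projector = None  # cache slot named in the decorator below

    @singleton('projector')
    def get_projector(self):
        print('building projector')  # executes only on the first call
        return object()

h = Holder()
a = h.get_projector()  # prints 'building projector'
b = h.get_projector()  # served from the cached attribute, no rebuild
assert a is b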

def get_module_device(module):
    return next(module.parameters()).device


def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val


def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
    _, _, orig_h, orig_w = image.shape

    ratio_lo, ratio_hi = ratio_range
    random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
    w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
    coor_x = floor((orig_w - w) * random.random())
    coor_y = floor((orig_h - h) * random.random())
    return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio


def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
    shape = image.shape
    output_size = default(output_size, shape[2:])
    (y0, y1), (x0, x1) = coordinates
    cutout_image = image[:, :, y0:y1, x0:x1]
    return F.interpolate(cutout_image, size = output_size, mode = mode)


def scale_coords(coords, scale):
    output = [[0,0],[0,0]]
    for j in range(2):
        for k in range(2):
            output[j][k] = int(coords[j][k] / scale)
    return output


def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
    blank = torch.zeros_like(image)
    coordinates = scale_coords(coordinates, scale_reduction)
    (y0, y1), (x0, x1) = coordinates
    orig_cutout_shape = (y1-y0, x1-x0)
    if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
        return None

    un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
    blank[:,:,y0:y1,x0:x1] = un_resized_img
    return blank


def compute_shared_coords(coords1, coords2, scale_reduction):
    (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
    (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
    shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
              (max(x1_l, x2_l), min(x1_r, x2_r)))
    for s in shared:
        if s == 0:
            return None
    return shared


def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
    # Unflip the pixel projections
    proj_pixel_one = flip_image_one_fn(proj_pixel_one)
    proj_pixel_two = flip_image_two_fn(proj_pixel_two)

    # Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
    scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
    proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
                                               mode=interp_mode)
    proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
                                               mode=interp_mode)
    if proj_pixel_one is None or proj_pixel_two is None:
        print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
        return None

    # Compute the shared coordinates for the two cutouts:
    shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
    if shared_coords is None:
        print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier).")
        return None
    (yt, yb), (xl, xr) = shared_coords

    return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
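A quick sketch of the cutout round trip these helpers implement (assuming the functions in this file): sample crop coordinates, resize the crop to full resolution, then paste it back into a zeroed canvas at its original location.

img = torch.randn(1, 3, 64, 64)
coords, ratio = cutout_coordinates(img, ratio_range=(0.6, 0.8))
crop = cutout_and_resize(img, coords, mode='nearest')                # still (1, 3, 64, 64)
restored = reverse_cutout_and_resize(crop, coords, scale_reduction=1.0)
(y0, y1), (x0, x1) = coords
assert restored[:, :, :y0, :].abs().sum() == 0                       # area outside the cutout stays zero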

# augmentation utils

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)


# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
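The EMA above is a plain exponential moving average: with beta = 0.99, every call moves a target parameter one percent of the way toward its online counterpart. A tiny numeric sketch:

ema = EMA(0.99)
old, new = torch.tensor(1.0), torch.tensor(0.0)
print(ema.update_average(old, new))  # tensor(0.9900) == 0.99 * 1.0 + 0.01 * 0.0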

# loss fn

def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
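Because both inputs are L2-normalized first, this is simply 2 - 2 * cosine_similarity(x, y): 0 for perfectly aligned vectors, 4 for exactly opposite ones. A quick check:

x = torch.randn(8, 256)
assert torch.allclose(loss_fn(x, x), torch.zeros(8), atol=1e-6)
assert torch.allclose(loss_fn(x, -x), torch.full((8,), 4.0), atol=1e-6)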

# classes

class MLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(chan, inner_dim),
            nn.BatchNorm1d(inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, chan_out)
        )

    def forward(self, x):
        return self.net(x)


class ConvMLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, inner_dim, 1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(),
            nn.Conv2d(inner_dim, chan_out, 1)
        )

    def forward(self, x):
        return self.net(x)


class PPM(nn.Module):
    def __init__(
        self,
        *,
        chan,
        num_layers = 1,
        gamma = 2):
        super().__init__()
        self.gamma = gamma

        if num_layers == 0:
            self.transform_net = nn.Identity()
        elif num_layers == 1:
            self.transform_net = nn.Conv2d(chan, chan, 1)
        elif num_layers == 2:
            self.transform_net = nn.Sequential(
                nn.Conv2d(chan, chan, 1),
                nn.BatchNorm2d(chan),
                nn.ReLU(),
                nn.Conv2d(chan, chan, 1)
            )
        else:
            raise ValueError('num_layers must be one of 0, 1, or 2')

    def forward(self, x):
        xi = x[:, :, :, :, None, None]
        xj = x[:, :, None, None, :, :]
        similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma

        transform_out = self.transform_net(x)
        out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
        return out
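PPM builds a pairwise pixel-similarity tensor of shape (b, x, y, h, w) and uses it to propagate the transformed features, so every output pixel becomes a similarity-weighted mix of all the other pixels. A shape sketch:

ppm = PPM(chan=256, num_layers=1, gamma=2)
feat = torch.randn(2, 256, 7, 7)
print(ppm(feat).shape)  # torch.Size([2, 256, 7, 7]) -- same layout, propagated features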

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projector and predictor nets

class NetWrapper(nn.Module):
    def __init__(
        self,
        *,
        net,
        instance_projection_size,
        instance_projection_hidden_size,
        pix_projection_size,
        pix_projection_hidden_size,
        layer_pixel = -2,
        layer_instance = -2
    ):
        super().__init__()
        self.net = net
        self.layer_pixel = layer_pixel
        self.layer_instance = layer_instance

        self.pixel_projector = None
        self.instance_projector = None

        self.instance_projection_size = instance_projection_size
        self.instance_projection_hidden_size = instance_projection_hidden_size
        self.pix_projection_size = pix_projection_size
        self.pix_projection_hidden_size = pix_projection_hidden_size

        self.hidden_pixel = None
        self.hidden_instance = None
        self.hook_registered = False

    def _find_layer(self, layer_id):
        if type(layer_id) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(layer_id, None)
        elif type(layer_id) == int:
            children = [*self.net.children()]
            return children[layer_id]
        return None

    def _hook(self, attr_name, _, __, output):
        setattr(self, attr_name, output)

    def _register_hook(self):
        pixel_layer = self._find_layer(self.layer_pixel)
        instance_layer = self._find_layer(self.layer_instance)

        assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
        assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'

        pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel'))
        instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance'))
        self.hook_registered = True

    @singleton('pixel_projector')
    def _get_pixel_projector(self, hidden):
        _, dim, *_ = hidden.shape
        projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
        return projector.to(hidden)

    @singleton('instance_projector')
    def _get_instance_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden_pixel = self.hidden_pixel
        hidden_instance = self.hidden_instance
        self.hidden_pixel = None
        self.hidden_instance = None
        assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
        assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
        return hidden_pixel, hidden_instance

    def forward(self, x):
        pixel_representation, instance_representation = self.get_representation(x)
        instance_representation = instance_representation.flatten(1)

        pixel_projector = self._get_pixel_projector(pixel_representation)
        instance_projector = self._get_instance_projector(instance_representation)

        pixel_projection = pixel_projector(pixel_representation)
        instance_projection = instance_projector(instance_representation)
        return pixel_projection, instance_projection
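NetWrapper never edits the backbone; it pulls intermediate activations out with forward hooks registered on the named (or indexed) layers. A minimal standalone sketch of that interception pattern, using a hypothetical toy backbone:

import torch
from torch import nn

backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
captured = {}

def hook(module, inputs, output):
    captured['hidden'] = output  # stash the intercepted activation

list(backbone.children())[-2].register_forward_hook(hook)  # -2 == the ReLU, like layer_pixel=-2
_ = backbone(torch.randn(1, 3, 32, 32))
print(captured['hidden'].shape)  # torch.Size([1, 8, 32, 32])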

# main class

class PixelCL(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel = -2,
        hidden_layer_instance = -2,
        instance_projection_size = 256,
        instance_projection_hidden_size = 2048,
        pix_projection_size = 256,
        pix_projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        prob_rand_hflip = 0.25,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        distance_thres = 0.7,
        similarity_temperature = 0.3,
        cutout_ratio_range = (0.6, 0.8),
        cutout_interpolate_mode = 'nearest',
        coord_cutout_interpolate_mode = 'bilinear',
        max_latent_dim = None  # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
    ):
        super().__init__()

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            # Normalize left out because it should be done at the model level.
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        self.online_encoder = NetWrapper(
            net = net,
            instance_projection_size = instance_projection_size,
            instance_projection_hidden_size = instance_projection_hidden_size,
            pix_projection_size = pix_projection_size,
            pix_projection_hidden_size = pix_projection_hidden_size,
            layer_pixel = hidden_layer_pixel,
            layer_instance = hidden_layer_instance
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature

        # This requirement is due to the way that these are processed, not a hard requirement.
        assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
        self.max_latent_dim = max_latent_dim

        self.propagate_pixels = PPM(
            chan = pix_projection_size,
            num_layers = ppm_num_layers,
            gamma = ppm_gamma
        )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, x):
        shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

        rand_flip_fn = lambda t: torch.flip(t, dims = (-1,))

        flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
        flip_image_one_fn = rand_flip_fn if flip_image_one else identity
        flip_image_two_fn = rand_flip_fn if flip_image_two else identity

        cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
        cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)

        image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode)
        image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode)

        image_one_cutout = flip_image_one_fn(image_one_cutout)
        image_two_cutout = flip_image_two_fn(image_two_cutout)

        image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)

        self.aug1 = image_one_cutout.detach().clone()
        self.aug2 = image_two_cutout.detach().clone()

        proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

        proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
                                                           cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
                                                           image_one_cutout.shape, self.cutout_interpolate_mode)
        if proj_pixel_one is None or proj_pixel_two is None:
            positive_pixel_pairs = 0
        else:
            positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]

        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
            target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
                                                                             cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
                                                                             image_one_cutout.shape, self.cutout_interpolate_mode)

        # If max_latent_dim is specified, stochastically extract latents from the shared areas.
        b, c, pp_h, pp_w = proj_pixel_one.shape
        if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
            prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
            latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
            extracted = []
            for l in latents:
                l = l.reshape(b, c, pp_h * pp_w)
                l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
                # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
                # Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves and others without regards
                # to the original image structure.
                sqdim = int(math.sqrt(self.max_latent_dim))
                extracted.append(l.reshape(b, c, sqdim, sqdim))
            proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted

        # flatten all the pixel projections
        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
        target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

        # get instance level loss
        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)
        loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if positive_pixel_pairs == 0:
            return instance_loss, 0

        # calculate pix pro loss
        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

        propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))

        propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
        propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)

        loss_pixpro_one_two = - propagated_similarity_one_two.mean()
        loss_pixpro_two_one = - propagated_similarity_two_one.mean()

        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

        return instance_loss, pix_loss, positive_pixel_pairs

    # Allows visualizing what the augmentor is up to.
    def visual_dbg(self, step, path):
        if not hasattr(self, 'aug1'):
            return
        torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
        torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))


@register_model
def register_pixel_contrastive_learner(opt_net, opt):
    subnet = create_model(opt, opt_net['subnet'])
    kwargs = opt_net['kwargs']
    if 'subnet_pretrain_path' in opt_net.keys():
        sd = torch.load(opt_net['subnet_pretrain_path'])
        subnet.load_state_dict(sd, strict=False)
    return PixelCL(subnet, **kwargs)
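Inside DLAS this learner is built from the experiment options by the registration function above; outside the trainer, a single training step would look roughly like the sketch below. The backbone choice, layer names and hyperparameters are illustrative assumptions, and kornia/einops must be installed.

import torch
from torchvision import models

backbone = models.resnet50()
learner = PixelCL(
    backbone,
    image_size=256,
    hidden_layer_pixel='layer4',   # layer whose output feeds the pixel projections
    hidden_layer_instance=-2,      # layer whose output feeds the instance projection
    max_latent_dim=256,            # must have an integer square root (16 * 16)
)
opt = torch.optim.Adam(learner.parameters(), lr=1e-4)

images = torch.randn(2, 3, 256, 256)
out = learner(images)
if len(out) == 3:
    instance_loss, pix_loss, positive_pairs = out
    loss = instance_loss + pix_loss
else:  # the two cutouts shared no region; only the instance loss is returned
    instance_loss, _ = out
    loss = instance_loss
loss.backward()
opt.step(); opt.zero_grad()
learner.update_moving_average()    # EMA step for the target encoder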
@@ -1,152 +0,0 @@
# Resnet implementation that adds a u-net style up-conversion component to output values at a
# specified pixel density.
#
# The downsampling part of the network is compatible with the built-in torch resnet for use in
# transfer learning.
#
# Only resnet50 currently supported.

import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision

from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class ReverseBottleneck(nn.Module):

    def __init__(self, inplanes, planes, groups=1, passthrough=False,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.passthrough = passthrough
        if passthrough:
            self.integrate = conv1x1(inplanes*2, inplanes)
            self.bn_integrate = norm_layer(inplanes)
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, groups, dilation)
        self.bn2 = norm_layer(width)
        self.residual_upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv1x1(width, width),
            norm_layer(width),
        )
        self.conv3 = conv1x1(width, planes)
        self.bn3 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv1x1(inplanes, planes),
            norm_layer(planes),
        )

    def forward(self, x, passthrough=None):
        if self.passthrough:
            x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.residual_upsample(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.upsample(x)

        out = out + identity
        out = self.relu(out)

        return out


class UResNet50(torchvision.models.resnet.ResNet):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, out_dim=128):
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        '''
        # For reference:
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        '''
        uplayers = []
        inplanes = 2048
        first = True
        for i in range(2):
            uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
            inplanes = inplanes // 2
            first = False
        self.uplayers = nn.ModuleList(uplayers)
        self.tail = nn.Sequential(conv1x1(1024, 512),
                                  norm_layer(512),
                                  nn.ReLU(),
                                  conv3x3(512, 512),
                                  norm_layer(512),
                                  nn.ReLU(),
                                  conv1x1(512, out_dim))

        del self.fc  # Not used in this implementation and just consumes a ton of GPU memory.

    def _forward_impl(self, x):
        # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
        # except using checkpoints on the body conv layers.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = checkpoint(self.layer1, x)
        x2 = checkpoint(self.layer2, x1)
        x3 = checkpoint(self.layer3, x2)
        x4 = checkpoint(self.layer4, x3)
        unused = self.avgpool(x4)  # This is performed for instance-level pixpro learning, even though it is unused.

        x = checkpoint(self.uplayers[0], x4)
        x = checkpoint(self.uplayers[1], x, x3)
        #x = checkpoint(self.uplayers[2], x, x2)
        #x = checkpoint(self.uplayers[3], x, x1)

        return checkpoint(self.tail, torch.cat([x, x2], dim=1))

    def forward(self, x):
        return self._forward_impl(x)


@register_model
def register_u_resnet50(opt_net, opt):
    model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if opt_get(opt_net, ['use_pretrained_base'], False):
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
        model.load_state_dict(state_dict, strict=False)
    return model


if __name__ == '__main__':
    model = UResNet50(Bottleneck, [3,4,6,3])
    samp = torch.rand(1,3,224,224)
    model(samp)
    # For pixpro: attach to "tail.3"
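The trailing note ("attach to tail.3") names the layer a pixel-level contrastive learner should hook. A hedged sketch of that wiring using the PixelCL class from the first deleted file; the parameter values are illustrative and not taken from any config in this commit:

model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=128)
learner = PixelCL(
    model,
    image_size=224,
    hidden_layer_pixel='tail.3',      # the conv layer referenced by the note above
    hidden_layer_instance='avgpool',  # instance branch taps the (otherwise unused) avgpool output
    max_latent_dim=256,
)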
@@ -1,87 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision

from models.arch_util import ConvBnRelu
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class UResNet50_2(torchvision.models.resnet.ResNet):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, out_dim=128):
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.level_conv = ConvBnRelu(3, 64)
        '''
        # For reference:
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        '''
        uplayers = []
        inplanes = 2048
        first = True
        div = [2,2,2,4,1]
        for i in range(5):
            uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
            inplanes = inplanes // div[i]
            first = False
        self.uplayers = nn.ModuleList(uplayers)
        self.tail = nn.Sequential(conv3x3(128, 64),
                                  norm_layer(64),
                                  nn.ReLU(),
                                  conv1x1(64, out_dim))

        del self.fc  # Not used in this implementation and just consumes a ton of GPU memory.

    def _forward_impl(self, x):
        level = self.level_conv(x)
        x0 = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x0)

        x1 = checkpoint(self.layer1, x)
        x2 = checkpoint(self.layer2, x1)
        x3 = checkpoint(self.layer3, x2)
        x4 = checkpoint(self.layer4, x3)
        unused = self.avgpool(x4)  # This is performed for instance-level pixpro learning, even though it is unused.

        x = checkpoint(self.uplayers[0], x4)
        x = checkpoint(self.uplayers[1], x, x3)
        x = checkpoint(self.uplayers[2], x, x2)
        x = checkpoint(self.uplayers[3], x, x1)
        x = checkpoint(self.uplayers[4], x, x0)

        return checkpoint(self.tail, torch.cat([x, level], dim=1))

    def forward(self, x):
        return self._forward_impl(x)


@register_model
def register_u_resnet50_2(opt_net, opt):
    model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if opt_get(opt_net, ['use_pretrained_base'], False):
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
        model.load_state_dict(state_dict, strict=False)
    return model


if __name__ == '__main__':
    model = UResNet50_2(Bottleneck, [3,4,6,3])
    samp = torch.rand(1,3,224,224)
    y = model(samp)
    print(y.shape)
    # For pixpro: attach to "tail.3"
@@ -1,86 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision

from models.arch_util import ConvBnRelu
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class UResNet50_3(torchvision.models.resnet.ResNet):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, out_dim=128):
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        '''
        # For reference:
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        '''
        uplayers = []
        inplanes = 2048
        first = True
        for i in range(3):
            uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
            inplanes = inplanes // 2
            first = False
        self.uplayers = nn.ModuleList(uplayers)

        # These two variables are separated out and renamed so that I can re-use parameters from a pretrained resnet_unet2.
        self.last_uplayer = ReverseBottleneck(256, 128, norm_layer=norm_layer, passthrough=True)
        self.tail3 = nn.Sequential(conv1x1(192, 128),
                                   norm_layer(128),
                                   nn.ReLU(),
                                   conv1x1(128, out_dim))

        del self.fc  # Not used in this implementation and just consumes a ton of GPU memory.

    def _forward_impl(self, x):
        x0 = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x0)

        x1 = checkpoint(self.layer1, x)
        x2 = checkpoint(self.layer2, x1)
        x3 = checkpoint(self.layer3, x2)
        x4 = checkpoint(self.layer4, x3)
        unused = self.avgpool(x4)  # This is performed for instance-level pixpro learning, even though it is unused.

        x = checkpoint(self.uplayers[0], x4)
        x = checkpoint(self.uplayers[1], x, x3)
        x = checkpoint(self.uplayers[2], x, x2)
        x = checkpoint(self.last_uplayer, x, x1)

        return checkpoint(self.tail3, torch.cat([x, x0], dim=1))

    def forward(self, x):
        return self._forward_impl(x)


@register_model
def register_u_resnet50_3(opt_net, opt):
    model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if opt_get(opt_net, ['use_pretrained_base'], False):
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
        model.load_state_dict(state_dict, strict=False)
    return model


if __name__ == '__main__':
    model = UResNet50_3(Bottleneck, [3,4,6,3])
    samp = torch.rand(1,3,224,224)
    y = model(samp)
    print(y.shape)
    # For pixpro: attach to "tail.3"
@@ -1,9 +0,0 @@
from models.styled_sr.discriminator import StyleSrGanDivergenceLoss


def create_stylesr_loss(opt_loss, env):
    type = opt_loss['type']
    if type == 'style_sr_gan_divergence_loss':
        return StyleSrGanDivergenceLoss(opt_loss, env)
    else:
        raise NotImplementedError
@@ -1,344 +0,0 @@
# Heavily based on the lucidrains stylegan2 discriminator implementation.
import math
import os
from functools import partial
from math import log2
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad as torch_grad
import trainer.losses as L
from vector_quantize_pytorch import VectorQuantize

from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
        super().__init__()
        self.filters = filters
        self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)

        self.net = nn.Sequential(
            TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
            leaky_relu(),
            TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class StyleSrDiscriminator(nn.Module):
    def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
                 transfer_mode=False):
        super().__init__()
        num_layers = int(log2(image_size) - 1)

        blocks = []
        filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        attn_blocks = []
        quantize_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
            blocks.append(block)

            attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None

            attn_blocks.append(attn_fn)

            if quantize:
                quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
                quantize_blocks.append(quantize_fn)
            else:
                quantize_blocks.append(None)

        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        self.quantize_blocks = nn.ModuleList(quantize_blocks)
        self.do_checkpointing = do_checkpointing

        chan_last = filters[-1]
        latent_dim = 2 * 2 * chan_last

        self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
        self.flatten = nn.Flatten()
        if mlp:
            self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
                                          leaky_relu(),
                                          TransferLinear(100, 1, transfer_mode=transfer_mode))
        else:
            self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)

        self._init_weights()

        self.transfer_mode = transfer_mode
        if transfer_mode:
            for p in self.parameters():
                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
                    p.DO_NOT_TRAIN = True

    def forward(self, x):
        b, *_ = x.shape

        quantize_loss = torch.zeros(1).to(x)

        for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
            if self.do_checkpointing:
                x = checkpoint(block, x)
            else:
                x = block(x)

            if exists(attn_block):
                x = attn_block(x)

            if exists(q_block):
                x, _, loss = q_block(x)
                quantize_loss += loss

        x = self.final_conv(x)
        x = self.flatten(x)
        x = self.to_logit(x)
        if exists(q_block):
            return x.squeeze(), quantize_loss
        else:
            return x.squeeze()

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {TransferConv2d, TransferLinear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    # Configures the network as partially pre-trained. This means:
    # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
    # 2) The head (linear layers) will also have their weights re-initialized
    # 3) All intermediate blocks will be frozen until step `frozen_until_step`
    # These settings will be applied after the weights have been loaded (network_loaded())
    def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
        self.bypass_blocks = bypass_blocks
        self.num_blocks = num_blocks
        self.frozen_until_step = frozen_until_step

    # Called after the network weights are loaded.
    def network_loaded(self):
        if not hasattr(self, 'frozen_until_step'):
            return

        if self.bypass_blocks > 0:
            self.blocks = self.blocks[self.bypass_blocks:]
            self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)

        reset_blocks = [self.to_logit]
        for i in range(self.num_blocks):
            reset_blocks.append(self.blocks[i])
        for bl in reset_blocks:
            for m in bl.modules():
                if type(m) in {TransferConv2d, TransferLinear}:
                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                    for p in m.parameters(recurse=True):
                        p._NEW_BLOCK = True
        for p in self.parameters():
            if not hasattr(p, '_NEW_BLOCK'):
                p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step


# helper classes
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()


def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
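DiffAugment applies every augmentation registered for the requested types, in order, so the discriminator only ever sees augmented images (real and generated alike). A short sketch using the functions above:

batch = torch.rand(4, 3, 64, 64)
augmented = DiffAugment(batch, types=['color', 'translation', 'cutout'])
print(augmented.shape)  # torch.Size([4, 3, 64, 64]) -- same shape, randomly jittered content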

class DiscAugmentor(nn.Module):
    def __init__(self, D, image_size, types, prob):
        super().__init__()
        self.D = D
        self.prob = prob
        self.types = types

    def forward(self, images, real_images=False):
        if random() < self.prob:
            images = random_hflip(images, prob=0.5)
            images = DiffAugment(images, types=self.types)

        if real_images:
            self.hq_aug = images.detach().clone()
        else:
            self.gen_aug = images.detach().clone()

        # Save away for use elsewhere (e.g. unet loss)
        self.aug_images = images

        return self.D(images)

    def network_loaded(self):
        self.D.network_loaded()

    # Allows visualizing what the augmentor is up to.
    def visual_dbg(self, step, path):
        torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
        torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def gradient_penalty(images, output, weight=10, return_structured_grads=False):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones(output.size(), device=images.device),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]

    flat_grad = gradients.reshape(batch_size, -1)
    penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
    if return_structured_grads:
        return penalty, gradients
    else:
        return penalty
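The penalty pushes the norm of the discriminator's gradient with respect to its (real) input toward 1, in the style of WGAN-GP regularization. A minimal sketch with a stand-in discriminator (the real one above expects `real_images=` and the DLAS wiring):

disc = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.Linear(8 * 16 * 16, 1))
real = torch.rand(2, 3, 16, 16, requires_grad=True)
scores = disc(real)
gp = gradient_penalty(real, scores, weight=10)
print(gp)  # scalar penalty; added to the discriminator loss every `gp_frequency` steps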

class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt['noise'] if 'noise' in opt.keys() else 0

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake = D(fake_input, real_images=False)
        if self.for_gen:
            return fake.mean()
        else:
            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
            real = D(real_input, real_images=True)
            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

            # Apply gradient penalty. TODO: migrate this elsewhere.
            if self.env['step'] % self.gp_frequency == 0:
                gp = gradient_penalty(real_input, real)
                self.metrics.append(("gradient_penalty", gp.clone().detach()))
                divergence_loss = divergence_loss + gp

            real_input.requires_grad_(requires_grad=False)
            return divergence_loss


@register_model
def register_styledsr_discriminator(opt_net, opt):
    attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
    disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
                                do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
                                quantize=opt_get(opt_net, ['quantize'], False),
                                mlp=opt_get(opt_net, ['mlp_head'], True),
                                transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
                                )
    if 'use_partial_pretrained' in opt_net.keys():
        disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
    return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
@@ -1,199 +0,0 @@
|
||||||
from random import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from models.arch_util import kaiming_init
|
|
||||||
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
|
|
||||||
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
|
|
||||||
from trainer.networks import register_model
|
|
||||||
from utils.util import checkpoint, opt_get
|
|
||||||
|
|
||||||
|
|
||||||
def rrdb_init_weights(module, scale=1):
|
|
||||||
for m in module.modules():
|
|
||||||
if isinstance(m, TransferConv2d):
|
|
||||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
|
||||||
m.weight.data *= scale
|
|
||||||
elif isinstance(m, TransferLinear):
|
|
||||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
|
||||||
m.weight.data *= scale
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderRRDB(nn.Module):
|
|
||||||
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
|
|
||||||
super(EncoderRRDB, self).__init__()
|
|
||||||
for i in range(5):
|
|
||||||
out_channels = output_channels if i == 4 else growth_channels
|
|
||||||
self.add_module(
|
|
||||||
f'conv{i+1}',
|
|
||||||
TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
for i in range(5):
|
|
||||||
rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.lrelu(self.conv1(x))
|
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
|
||||||
return x5
|
|
||||||
|
|
||||||
|
|
||||||
class StyledSrEncoder(nn.Module):
|
|
||||||
def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
|
|
||||||
super().__init__()
|
|
||||||
# Current assumes fea_out=256.
|
|
||||||
self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
|
|
||||||
self.rrdbs = nn.ModuleList([
|
|
||||||
EncoderRRDB(32, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(64, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(96, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(128, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(160, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(192, transfer_mode=transfer_mode),
|
|
||||||
EncoderRRDB(224, transfer_mode=transfer_mode)])
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
fea = self.initial_conv(x)
|
|
||||||
for rrdb in self.rrdbs:
|
|
||||||
fea = torch.cat([fea, checkpoint(rrdb, fea)], dim=1)
|
|
||||||
return fea


class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
        super().__init__()
        total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
        self.image_size = image_size
        self.scale = 2 ** upsample_levels
        self.latent_dim = latent_dim
        self.num_layers = total_levels
        self.transfer_mode = transfer_mode
        filters = [
            512,  # 4x4
            512,  # 8x8
            512,  # 16x16
            256,  # 32x32
            128,  # 64x64
            64,   # 128x128
            32,   # 256x256
            16,   # 512x512
            8,    # 1024x1024
        ]

        # I'm making a guess here that the encoder does not need transfer learning, hence transfer_mode=False is fixed. This should be vetted.
        self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)

        in_out_pairs = list(zip(filters[:-1], filters[1:]))
        self.blocks = nn.ModuleList([])
        for ind in range(start_level, start_level + total_levels):
            in_chan, out_chan = in_out_pairs[ind]
            not_first = ind != start_level
            not_last = ind != (start_level + total_levels - 1)
            block = GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                transfer_learning_mode=transfer_mode
            )
            self.blocks.append(block)

    def forward(self, lr, styles):
        b, c, h, w = lr.shape
        if self.transfer_mode:
            with torch.no_grad():
                x = self.encoder(lr)
        else:
            x = self.encoder(lr)

        styles = styles.transpose(0, 1)
        input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
        if h != x.shape[-2]:
            rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
        else:
            rgb = lr

        for style, block in zip(styles, self.blocks):
            x, rgb = checkpoint(block, x, rgb, style, input_noise)

        return rgb


class StyledSrGenerator(nn.Module):
    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
        super().__init__()
        # Assume the vectorizer doesn't need transfer_mode=True. Re-evaluate this later.
        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
        self.l2 = nn.MSELoss()
        self.mixed_prob = .9
        self._init_weights()
        self.transfer_mode = transfer_mode
        self.initial_stride = initial_stride
        if transfer_mode:
            for p in self.parameters():
                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
                    p.DO_NOT_TRAIN = True

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

        for block in self.gen.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)

    def forward(self, x):
        b, f, h, w = x.shape

        # Synthesize style latents from noise.
        style = torch.randn(b * 2, self.gen.latent_dim).to(x.device)
        if self.transfer_mode:
            with torch.no_grad():
                w = self.vectorizer(style)
        else:
            w = self.vectorizer(style)

        # Randomly distribute styles across layers.
        w_styles = w[:, None, :].expand(-1, self.gen.num_layers, -1).clone()
        for j in range(b):
            cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
            if cutoff == self.gen.num_layers or random() > self.mixed_prob:
                w_styles[j] = w_styles[j*2]
            else:
                w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
                w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
        w_styles = w_styles[:b]

        out = self.gen(x, w_styles)

        # Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
        # for regularization.
        out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
        if self.initial_stride > 1:
            x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
        l2_reg = self.l2(x, out_down)

        return out, l2_reg, w_styles
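

# Hedged sketch (not part of the original file): the FOR_TRANSFER_LEARNING /
# DO_NOT_TRAIN flags set in StyledSrGenerator.__init__ only matter if the
# training loop filters parameters on them. A minimal version of that filter,
# assuming the trainer simply skips flagged parameters when it builds the
# optimizer, might look like this:
def _example_trainable_parameters(model: nn.Module):
    # Keep only parameters that were not explicitly frozen for transfer learning.
    return [p for p in model.parameters() if not getattr(p, 'DO_NOT_TRAIN', False)]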


if __name__ == '__main__':
    gen = StyledSrGenerator(128, 2)
    out = gen(torch.rand(1, 3, 64, 64))
    print([o.shape for o in out])


@register_model
def register_styled_sr(opt_net, opt):
    return StyledSrGenerator(128,
                             initial_stride=opt_get(opt_net, ['initial_stride'], 1),
                             transfer_mode=opt_get(opt_net, ['transfer_mode'], False))

@@ -1,411 +0,0 @@
import math
import multiprocessing
from contextlib import contextmanager, ExitStack

import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn, Tensor
from torch.autograd import grad as torch_grad
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd

from models.styled_sr.transfer_primitives import TransferLinear

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

num_cores = multiprocessing.cpu_count()

# constants
EPS = 1e-8


class NanException(Exception):
    pass


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class PermuteToFrom(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)


# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])


# helpers

def exists(val):
    return val is not None


@contextmanager
def null_context():
    yield


def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts


def default(value, d):
    return value if exists(value) else d


def cycle(iterable):
    while True:
        for i in iterable:
            yield i


def cast_list(el):
    return el if isinstance(el, list) else [el]


def is_empty(t):
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return not exists(t)


def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException


def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    # Note: the fp16 path assumes NVIDIA apex's `amp` is available in scope.
    if fp16:
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def calc_pl_lengths(styles, images):
    device = images.device
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()


def image_noise(n, im_size, device):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool


def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
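

# Hedged usage sketch (illustrative, not in the original file): slerp is the
# standard spherical interpolation between two latent batches; walking val from
# 0 to 1 moves from `low` to `high` along the unit hypersphere.
def _example_slerp_walk(steps=5, latent_dim=512):
    low, high = torch.randn(1, latent_dim), torch.randn(1, latent_dim)
    return [slerp(t / (steps - 1), low, high) for t in range(steps)]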


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul

        self.transfer_mode = transfer_mode
        if transfer_mode:
            # The transfer parameters must match self.weight, i.e. shape (out_dim, in_dim).
            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def forward(self, input):
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)


class StyleVectorizer(nn.Module):
    def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = F.normalize(x, dim=1)
        return self.net(x)


class RGBBlock(nn.Module):
    def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            Blur()
        ) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if exists(prev_rgb):
            x = x + prev_rgb

        if exists(self.upsample):
            x = self.upsample(x)

        return x


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
        from models.archs.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        gamma = self.style2scale(style)
        beta = self.style2bias(style)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)
    return module


class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight

        w1 = y[:, None, :, None, None]
        w2 = weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x
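

# Hedged shape sketch (illustrative, not in the original file): Conv2DMod is a
# StyleGAN2-style modulated convolution; the per-sample style vector y scales the
# shared kernel, which is why the input is regrouped by batch before F.conv2d.
def _example_conv2dmod_shapes():
    conv = Conv2DMod(in_chan=64, out_chan=32, kernel=3)
    x = torch.randn(4, 64, 16, 16)   # image features [b, c, h, w]
    y = torch.randn(4, 64)           # one style per sample, sized to in_chan
    return conv(x, y).shape          # -> torch.Size([4, 32, 16, 16])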


class GeneratorBlock(nn.Module):
    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
                 transfer_learning_mode=False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
        self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
        self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)

        self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
        self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
        self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)

        self.transfer_learning_mode = transfer_learning_mode

    def forward(self, x, prev_rgb, istyle, inoise):
        if exists(self.upsample):
            x = self.upsample(x)

        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb

@@ -1,136 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _ntuple

_pair = _ntuple(2)


class TransferConv2d(_ConvNd):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            transfer_mode: bool = False
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _conv_forward(self, input, weight):
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight)


class TransferLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class TransferConvGnLelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
        super().__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
            self.gn = None
        if activation:
            self.lelu = nn.LeakyReLU(negative_slope=.2)
        else:
            self.lelu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, TransferConv2d):
                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.gn:
            x = self.gn(x)
        if self.lelu:
            return self.lelu(x)
        else:
            return x
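

# Hedged sketch (illustrative, not in the original file): in transfer mode the
# effective weight is reparameterized as weight * transfer_scale + transfer_shift,
# so a pretrained kernel can be adapted by training only the per channel-pair
# scale/shift while the base weight stays frozen.
def _example_transfer_parameters():
    conv = TransferConv2d(16, 32, kernel_size=3, padding=1, transfer_mode=True)
    return [n for n, p in conv.named_parameters() if hasattr(p, 'FOR_TRANSFER_LEARNING')]
    # -> ['transfer_scale', 'transfer_shift']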

@@ -1,15 +0,0 @@
import munch
import torch

from trainer.networks import register_model


@register_model
def register_flownet2(opt_net):
    from models.flownet2.models import FlowNet2
    ld = 'load_path' in opt_net.keys()
    args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
    netG = FlowNet2(args)
    if ld:
        sd = torch.load(opt_net['load_path'])
        netG.load_state_dict(sd['state_dict'])
    # The registration hook must hand the constructed network back to the trainer.
    return netG
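

# Hedged sketch (illustrative; the exact option schema comes from the experiment
# YML and is an assumption here): registration is driven by a plain dict such as
#   opt_net = {'type': 'flownet2', 'load_path': '../experiments/pretrained_flownet2.pth'}
# When 'load_path' is absent, the Munch args set checkpoint=True instead,
# presumably so FlowNet2 runs with checkpointing rather than pretrained weights.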

@@ -1,79 +0,0 @@
import os

import torch
import torch.nn as nn
import torchvision

from trainer.networks import register_model
from utils.util import sequential_checkpoint
from models.arch_util import ConvGnSilu, make_layer


class TecoResblock(nn.Module):
    def __init__(self, nf):
        super(TecoResblock, self).__init__()
        self.nf = nf
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        return identity + x


class TecoUpconv(nn.Module):
    def __init__(self, nf, scale):
        super(TecoUpconv, self).__init__()
        self.nf = nf
        self.scale = scale
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
        x = self.conv3(x)
        return self.final_conv(x)


# Extremely simple resnet-based generator that is very similar to the one used in the TecoGAN paper.
# Main differences:
# - Uses SiLU instead of ReLU.
# - The reference input is in HR space (which just makes more sense).
# - Doesn't use transposed convolutions - interpolation is used instead.
# - The upsample block is slightly more complicated.
class TecoGen(nn.Module):
    def __init__(self, nf, scale):
        super(TecoGen, self).__init__()
        self.nf = nf
        self.scale = scale
        fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
        res_layers = [TecoResblock(nf) for i in range(15)]
        upsample = TecoUpconv(nf, scale)
        everything = [fea_conv] + res_layers + [upsample]
        self.core = nn.Sequential(*everything)

    def forward(self, x, ref=None):
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
        if ref is None:
            ref = torch.zeros_like(x)
        join = torch.cat([x, ref], dim=1)
        join = sequential_checkpoint(self.core, 6, join)
        self.join = join.detach().clone() + .5
        return x + join

    def visual_dbg(self, step, path):
        torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))

    def get_debug_values(self, step, net_name):
        return {'branch_std': self.join.std()}


@register_model
def register_tecogen(opt_net, opt):
    return TecoGen(opt_net['nf'], opt_net['scale'])
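

# Hedged usage sketch (illustrative, not in the original file): the generator
# consumes a bicubically-upsampled LR frame concatenated with an HR-space
# reference (e.g. the previous output frame) and predicts a residual on top of
# the bicubic baseline.
def _example_tecogen_forward():
    gen = TecoGen(nf=64, scale=4)
    lr_frame = torch.randn(1, 3, 32, 32)
    prev_hr = torch.randn(1, 3, 128, 128)    # HR-space reference, 32 * scale
    return gen(lr_frame, ref=prev_hr).shape  # -> torch.Size([1, 3, 128, 128])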

@@ -1,155 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.inject import Injector
from trainer.networks import register_model
from utils.util import checkpoint


def create_injector(opt, env):
    type = opt['type']
    if type == 'igpt_resolve':
        return ResolveInjector(opt, env)
    return None


class ResolveInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen = opt['generator']
        self.samples = opt['num_samples']
        self.temperature = opt['temperature']

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']].module
        img = state[self.input]
        b, c, h, w = img.shape
        qimg = gen.quantize(img)
        s, b = qimg.shape
        qimg = qimg[:s//2, :]
        output = qimg.repeat(1, self.samples)

        pad = torch.zeros(1, self.samples, dtype=torch.long).cuda()  # to pad prev output
        with torch.no_grad():
            for _ in range(s//2):
                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
                logits = logits[-1, :, :] / self.temperature
                probs = F.softmax(logits, dim=-1)
                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
                output = torch.cat((output, pred), dim=0)
        output = gen.unquantize(output.reshape(h, w, -1))
        return {self.output: output.permute(2, 3, 0, 1).contiguous()}


class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        attn_mask = torch.full(
            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
        )
        attn_mask = torch.triu(attn_mask, diagonal=1)

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class iGPT2(nn.Module):
    def __init__(
            self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
    ):
        super().__init__()

        self.centroids = nn.Parameter(
            torch.from_numpy(np.load(centroids_file)), requires_grad=False
        )
        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        self.clf_head = nn.Linear(embed_dim, 10)  # Fixed num_classes; this is not used as a classifier.

    def squared_euclidean_distance(self, a, b):
        b = torch.transpose(b, 0, 1)
        a2 = torch.sum(torch.square(a), dim=1, keepdims=True)
        b2 = torch.sum(torch.square(b), dim=0, keepdims=True)
        ab = torch.matmul(a, b)
        d = a2 - 2 * ab + b2
        return d

    def quantize(self, x):
        b, c, h, w = x.shape
        # [B, C, H, W] => [B, H, W, C]
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(-1, c)  # flatten to pixels
        d = self.squared_euclidean_distance(x, self.centroids)
        x = torch.argmin(d, 1)
        x = x.view(b, h, w)

        # Reshape output to [seq_len, batch].
        x = x.view(x.shape[0], -1)  # flatten images into sequences
        x = x.transpose(0, 1).contiguous()  # to shape [seq_len, batch]
        return x

    def unquantize(self, x):
        return self.centroids[x]

    def forward(self, x, already_quantized=False):
        """
        Expects input of shape [b, c, h, w].
        """
        if not already_quantized:
            x = self.quantize(x)
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token
        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
        h = torch.cat([sos, h[:-1, :, :]], axis=0)

        # add positional embeddings
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)

        # transformer
        for layer in self.layers:
            h = checkpoint(layer, h)

        h = self.ln_f(h)

        logits = self.head(h)

        return logits, x


@register_model
def register_igpt2(opt_net, opt):
    return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
                 opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
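

# Hedged shape sketch (illustrative; the centroid count and file name are
# assumptions): quantize() maps each pixel to its nearest color centroid, turning
# a [b, 3, h, w] image into a [h*w, b] token sequence that the transformer then
# models autoregressively.
def _example_igpt2_shapes(centroids_file='centroids.npy'):
    # Assumes centroids_file holds a (512, 3) array of RGB cluster centers.
    model = iGPT2(embed_dim=256, num_heads=8, num_layers=4, num_positions=32 * 32,
                  num_vocab=512, centroids_file=centroids_file)
    img = torch.rand(2, 3, 32, 32)
    logits, tokens = model(img)
    return logits.shape, tokens.shape  # -> ([1024, 2, 512], [1024, 2])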

@@ -1,42 +0,0 @@
import numpy
import torch
from torch.utils.data import DataLoader

from data.torch_dataset import TorchDataset
from models.classifiers.cifar_resnet_branched import ResNet
from models.classifiers.cifar_resnet_branched import BasicBlock

if __name__ == '__main__':
    dopt = {
        'flip': True,
        'crop_sz': None,
        'dataset': 'cifar100',
        'image_size': 32,
        'normalize': False,
        'kwargs': {
            'root': 'E:\\4k6k\\datasets\\images\\cifar100',
            'download': True
        }
    }
    set = TorchDataset(dopt)
    loader = DataLoader(set, num_workers=0, batch_size=32)
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth'))
    model.eval()

    bins = [[] for _ in range(8)]
    for i, batch in enumerate(loader):
        logits, selector = model(batch['hq'], coarse_label=None, return_selector=True)
        for k, s in enumerate(selector):
            for j, b in enumerate(s):
                if b:
                    bins[j].append(batch['labels'][k].item())
        if i > 10:
            break

    import matplotlib.pyplot as plt
    fig, axs = plt.subplots(3, 3)
    for i in range(8):
        axs[i % 3, i // 3].hist(numpy.asarray(bins[i]))
    plt.show()
    print('hi')

@@ -1,70 +0,0 @@
import torch
import numpy as np
from utils import options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from utils.fdpl_util import dct_2d, extract_patches_2d
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F

input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size = 128

if __name__ == '__main__':
    opt = option.parse(input_config, is_train=True)
    opt['dist'] = False

    # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
    dataset_opt = opt['datasets']['train']
    train_set = create_dataset(dataset_opt)
    train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    train_loader = create_dataloader(train_set, dataset_opt, opt, None)
    print('Number of train images: {:,d}, iters: {:,d}'.format(
        len(train_set), train_size))

    # calculate the perceptual weights
    master_diff = np.zeros((patch_size, patch_size))
    num_patches = 0
    all_diff_patches = []
    tq = tqdm(train_loader)
    sampled = 0
    for train_data in tq:
        if sampled > 200:
            break
        sampled += 1

        im = rgb2ycbcr(train_data['hq'].double())
        im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
                                        size=im.shape[2:],
                                        mode="bicubic", align_corners=False))
        patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size, patch_size), batch_first=True)
        patches_hr = dct_2d(patches_hr, norm='ortho')
        patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size, patch_size), batch_first=True)
        patches_lr = dct_2d(patches_lr, norm='ortho')
        b, p, c, w, h = patches_hr.shape
        # Relative DCT-coefficient difference: |LR - HR| / (mean(|LR|, |HR|) + eps).
        diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
        num_patches += b * p
        all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))

    diff_patches = torch.stack(all_diff_patches, dim=0)
    diff_means = torch.sum(diff_patches, dim=0) / num_patches

    torch.save(diff_means, output_file)
    print(diff_means)

    for i in range(3):
        fig, ax = plt.subplots()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        im = ax.imshow(diff_means[i].numpy())
        ax.set_title("mean_diff for channel %i" % (i,))
        fig.colorbar(im, cax=cax, orientation='vertical')
        plt.show()

@@ -1,411 +0,0 @@
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""

import sys
import os.path as osp
import glob
import pickle
from multiprocessing import Pool
import numpy as np
import lmdb
import cv2

sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import data.util as data_util  # noqa: E402
import utils.util as util  # noqa: E402


def main():
    dataset = 'DIV2K_demo'  # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo | test
    mode = 'hq'  # used for the vimeo90k and REDS datasets
    # vimeo90k: GT | LR | flow
    # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp,
    #       train_sharp_flowx4
    if dataset == 'vimeo90k':
        vimeo90k(mode)
    elif dataset == 'REDS':
        REDS(mode)
    elif dataset == 'general':
        opt = {}
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
    elif dataset == 'DIV2K_demo':
        opt = {}
        ## GT
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
        ## LR
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
        opt['name'] = 'DIV2K800_sub_bicLRx4'
        general_image_folder(opt)
    elif dataset == 'test':
        test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')


def read_image_worker(path, key):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return (key, img)


def general_image_folder(opt):
    """Create lmdb for general image folders.
    Users should define the keys, such as: '0321_s035' for DIV2K sub-images.
    If all the images have the same resolution, only one copy of the resolution info is stored.
    Otherwise, the resolution info is stored for every image.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False to use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits (only when read_all_imgs = False)
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    keys = []
    for img_path in all_img_list:
        keys.append(osp.splitext(osp.basename(img_path))[0])

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    resolutions = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key_byte, data)
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print('Not all images have the same resolution. Save meta info for each image.')

    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
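

# Hedged sketch of the resulting meta_info.pkl contents (values are illustrative):
#   {'name': 'DIV2K800_sub_GT',
#    'resolution': ['3_480_480'],          # a single entry when all images match
#    'keys': ['0001_s001', '0001_s002', ...]}
# A reader pairs each key with the 'C_H_W' resolution string to reshape the raw
# bytes fetched from lmdb, as test_lmdb() below demonstrates.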
|
|
||||||
|
|
||||||
|
|
||||||
def vimeo90k(mode):
|
|
||||||
"""Create lmdb for the Vimeo90K dataset, each image with a fixed size
|
|
||||||
GT: [3, 256, 448]
|
|
||||||
Now only need the 4th frame, e.g., 00001_0001_4
|
|
||||||
LR: [3, 64, 112]
|
|
||||||
1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
|
|
||||||
key:
|
|
||||||
Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
|
|
||||||
|
|
||||||
flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
|
|
||||||
Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
|
|
||||||
Flow map is quantized by mmcv and saved in png format
|
|
||||||
"""
|
|
||||||
#### configurations
|
|
||||||
read_all_imgs = False # whether real all images to memory with multiprocessing
|
|
||||||
# Set False for use limited memory
|
|
||||||
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
|
|
||||||
if mode == 'hq':
|
|
||||||
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
|
|
||||||
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
|
|
||||||
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
|
||||||
H_dst, W_dst = 256, 448
|
|
||||||
elif mode == 'LR':
|
|
||||||
img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
|
|
||||||
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
|
|
||||||
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
|
||||||
H_dst, W_dst = 64, 112
|
|
||||||
elif mode == 'flow':
|
|
||||||
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
|
|
||||||
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
|
|
||||||
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
|
|
||||||
H_dst, W_dst = 128, 112
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong dataset mode: {}'.format(mode))
|
|
||||||
n_thread = 40
|
|
||||||
########################################################
|
|
||||||
if not lmdb_save_path.endswith('.lmdb'):
|
|
||||||
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
|
|
||||||
if osp.exists(lmdb_save_path):
|
|
||||||
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
#### read all the image paths to a list
|
|
||||||
print('Reading image path list ...')
|
|
||||||
with open(txt_file) as f:
|
|
||||||
train_l = f.readlines()
|
|
||||||
train_l = [v.strip() for v in train_l]
|
|
||||||
all_img_list = []
|
|
||||||
keys = []
|
|
||||||
for line in train_l:
|
|
||||||
folder = line.split('/')[0]
|
|
||||||
sub_folder = line.split('/')[1]
|
|
||||||
all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
|
|
||||||
if mode == 'flow':
|
|
||||||
for j in range(1, 4):
|
|
||||||
keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
|
|
||||||
keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
|
|
||||||
else:
|
|
||||||
for j in range(7):
|
|
||||||
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
|
|
||||||
all_img_list = sorted(all_img_list)
|
|
||||||
keys = sorted(keys)
|
|
||||||
if mode == 'hq': # only read the 4th frame for the GT mode
|
|
||||||
print('Only keep the 4th frame.')
|
|
||||||
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
|
|
||||||
keys = [v for v in keys if v.endswith('_4')]
|
|
||||||
|
|
||||||
if read_all_imgs:
|
|
||||||
#### read all images to memory (multiprocessing)
|
|
||||||
dataset = {} # store all image data. list cannot keep the order, use dict
|
|
||||||
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
|
|
||||||
pbar = util.ProgressBar(len(all_img_list))
|
|
||||||
|
|
||||||
def mycallback(arg):
|
|
||||||
"""get the image data and update pbar"""
|
|
||||||
key = arg[0]
|
|
||||||
dataset[key] = arg[1]
|
|
||||||
pbar.update('Reading {}'.format(key))
|
|
||||||
|
|
||||||
pool = Pool(n_thread)
|
|
||||||
for path, key in zip(all_img_list, keys):
|
|
||||||
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
|
|
||||||
pool.close()
|
|
||||||
pool.join()
|
|
||||||
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
|
|
||||||
|
|
||||||
#### write data to lmdb
|
|
||||||
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
|
|
||||||
print('data size per image is: ', data_size_per_img)
|
|
||||||
data_size = data_size_per_img * len(all_img_list)
|
|
||||||
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
|
|
||||||
txn = env.begin(write=True)
|
|
||||||
pbar = util.ProgressBar(len(all_img_list))
|
|
||||||
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
|
|
||||||
pbar.update('Write {}'.format(key))
|
|
||||||
key_byte = key.encode('ascii')
|
|
||||||
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
|
||||||
if 'flow' in mode:
|
|
||||||
H, W = data.shape
|
|
||||||
assert H == H_dst and W == W_dst, 'different shape.'
|
|
||||||
else:
|
|
||||||
H, W, C = data.shape
|
|
||||||
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
|
|
||||||
txn.put(key_byte, data)
|
|
||||||
if not read_all_imgs and idx % BATCH == 0:
|
|
||||||
txn.commit()
|
|
||||||
txn = env.begin(write=True)
|
|
||||||
txn.commit()
|
|
||||||
env.close()
|
|
||||||
print('Finish writing lmdb.')
|
|
||||||
|
|
||||||
#### create meta information
|
|
||||||
meta_info = {}
|
|
||||||
if mode == 'hq':
|
|
||||||
meta_info['name'] = 'Vimeo90K_train_GT'
|
|
||||||
elif mode == 'lq':
|
|
||||||
meta_info['name'] = 'Vimeo90K_train_LR'
|
|
||||||
elif mode == 'flow':
|
|
||||||
meta_info['name'] = 'Vimeo90K_train_flowx4'
|
|
||||||
channel = 1 if 'flow' in mode else 3
|
|
||||||
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
|
|
||||||
key_set = set()
|
|
||||||
for key in keys:
|
|
||||||
if mode == 'flow':
|
|
||||||
a, b, _, _ = key.split('_')
|
|
||||||
else:
|
|
||||||
a, b, _ = key.split('_')
|
|
||||||
key_set.add('{}_{}'.format(a, b))
|
|
||||||
meta_info['keys'] = list(key_set)
|
|
||||||
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
|
|
||||||
print('Finish creating lmdb meta info.')
|
|
||||||
|
|
||||||
|
|
||||||
def REDS(mode):
|
|
||||||
"""Create lmdb for the REDS dataset, each image with a fixed size
|
|
||||||
GT: [3, 720, 1280], key: 000_00000000
|
|
||||||
LR: [3, 180, 320], key: 000_00000000
|
|
||||||
key: 000_00000000
|
|
||||||
|
|
||||||
flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
|
|
||||||
Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
|
|
||||||
Flow map is quantized by mmcv and saved in png format
|
|
||||||
"""
|
|
||||||
#### configurations
|
|
||||||
read_all_imgs = False # whether real all images to memory with multiprocessing
|
|
||||||
# Set False for use limited memory
|
|
||||||
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
|
|
||||||
if mode == 'train_sharp':
|
|
||||||
img_folder = '../../datasets/REDS/train_sharp'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
|
|
||||||
H_dst, W_dst = 720, 1280
|
|
||||||
elif mode == 'train_sharp_bicubic':
|
|
||||||
img_folder = '../../datasets/REDS/train_sharp_bicubic'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
|
|
||||||
H_dst, W_dst = 180, 320
|
|
||||||
elif mode == 'train_blur_bicubic':
|
|
||||||
img_folder = '../../datasets/REDS/train_blur_bicubic'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
|
|
||||||
H_dst, W_dst = 180, 320
|
|
||||||
elif mode == 'train_blur':
|
|
||||||
img_folder = '../../datasets/REDS/train_blur'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
|
|
||||||
H_dst, W_dst = 720, 1280
|
|
||||||
elif mode == 'train_blur_comp':
|
|
||||||
img_folder = '../../datasets/REDS/train_blur_comp'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
|
|
||||||
H_dst, W_dst = 720, 1280
|
|
||||||
elif mode == 'train_sharp_flowx4':
|
|
||||||
img_folder = '../../datasets/REDS/train_sharp_flowx4'
|
|
||||||
lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
|
|
||||||
H_dst, W_dst = 360, 320
|
|
||||||
n_thread = 40
|
|
||||||
########################################################
|
|
||||||
if not lmdb_save_path.endswith('.lmdb'):
|
|
||||||
raise ValueError("lmdb_save_path must end with \'lmdb\'.")
|
|
||||||
if osp.exists(lmdb_save_path):
|
|
||||||
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
#### read all the image paths to a list
|
|
||||||
print('Reading image path list ...')
|
|
||||||
all_img_list = data_util._get_paths_from_images(img_folder)
|
|
||||||
keys = []
|
|
||||||
for img_path in all_img_list:
|
|
||||||
split_rlt = img_path.split('/')
|
|
||||||
folder = split_rlt[-2]
|
|
||||||
img_name = split_rlt[-1].split('.png')[0]
|
|
||||||
keys.append(folder + '_' + img_name)
|
|
||||||
|
|
||||||
if read_all_imgs:
|
|
||||||
#### read all images to memory (multiprocessing)
|
|
||||||
dataset = {} # store all image data. list cannot keep the order, use dict
|
|
||||||
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
|
|
||||||
pbar = util.ProgressBar(len(all_img_list))
|
|
||||||
|
|
||||||
def mycallback(arg):
|
|
||||||
'''get the image data and update pbar'''
|
|
||||||
key = arg[0]
|
|
||||||
dataset[key] = arg[1]
|
|
||||||
pbar.update('Reading {}'.format(key))
|
|
||||||
|
|
||||||
pool = Pool(n_thread)
|
|
||||||
for path, key in zip(all_img_list, keys):
|
|
||||||
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
|
|
||||||
pool.close()
|
|
||||||
pool.join()
|
|
||||||
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
|
|
||||||
|
|
||||||
#### create lmdb environment
|
|
||||||
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
|
|
||||||
print('data size per image is: ', data_size_per_img)
|
|
||||||
data_size = data_size_per_img * len(all_img_list)
|
|
||||||
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
|
|
||||||
|
|
||||||
    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def test_lmdb(dataroot, dataset='REDS'):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
    print('Name: ', meta_info['name'])
    print('Resolution: ', meta_info['resolution'])
    print('# keys: ', len(meta_info['keys']))
    # read one image
    if dataset == 'vimeo90k':
        key = '00001_0001_4'
    else:
        key = '000_00000000'
    print('Reading {} for test.'.format(key))
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
    img = img_flat.reshape(H, W, C)
    cv2.imwrite('test.png', img)


if __name__ == "__main__":
    main()
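The pool above hands each image path to read_image_worker, which is defined earlier in this script and not shown in this hunk. A minimal sketch of its assumed behavior (load the image unchanged and return the key together with the pixel data, so mycallback can file it into dataset):

def read_image_worker(path, key):
    # assumed helper: decode the image as-is and pair it with its lmdb key
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return (key, img)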
@@ -1,22 +0,0 @@
from torch.utils.tensorboard import SummaryWriter

if __name__ == "__main__":
    writer = SummaryWriter("../experiments/recovered_tb")
    f = open("../experiments/recovered_tb.txt", encoding="utf8")
    console = f.readlines()
    search_terms = [
        ("iter", ", iter: ", ", lr:"),
        ("l_g_total", " l_g_total: ", " switch_temperature:"),
        ("l_d_fake", "l_d_fake: ", " D_fake:")
    ]
    iter = 0
    for line in console:
        if " - INFO: [epoch:" not in line:
            continue
        for name, start, end in search_terms:
            val = line[line.find(start)+len(start):line.find(end)].replace(",", "")
            if name == "iter":
                iter = int(val)
            else:
                writer.add_scalar(name, float(val), iter)
    writer.close()
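The recovery script above slices scalar values back out of raw console logs by substring position. A hypothetical log line in the assumed format (the exact layout depends on the original training logger), and what the three search terms would pull out of it:

# hypothetical example of a line the script expects:
line = "... - INFO: [epoch:  3, iter:  12,500, lr:(1e-4,)] l_g_total: 1.2340e+00 switch_temperature: 1.0 l_d_fake: 5.0000e-01 D_fake: 0.12"
# 'iter'      -> text between ', iter: ' and ', lr:'                  -> '12500' -> step 12500
# 'l_g_total' -> text between ' l_g_total: ' and ' switch_temperature:' -> 1.234
# 'l_d_fake'  -> text between 'l_d_fake: ' and ' D_fake:'             -> 0.5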
@@ -1,19 +0,0 @@
import os
import glob


def main():
    folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
    DIV2K(folder)
    print('Finished.')


def DIV2K(path):
    img_path_l = glob.glob(os.path.join(path, '*'))
    for img_path in img_path_l:
        new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
        os.rename(img_path, new_path)


if __name__ == "__main__":
    main()
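For reference, the effect of the rename on a typical DIV2K LR filename (example path; only the lowercase scale suffix in the filename matches, the uppercase X4 directory is untouched):

old = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4/0801x4.png'
new = old.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
# new == 'datasets/div2k/DIV2K_valid_LR_bicubic/X4/0801.png'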
@@ -1,83 +0,0 @@
import os.path as osp
import logging
import time
import argparse

import os

import torchvision

import utils
import utils.options as option
import utils.util as util
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch

if __name__ == "__main__":
    #### options
    torch.backends.cudnn.benchmark = True
    srg_analyze = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        dataset_opt['n_workers'] = 0
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = ExtensibleTrainer(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        dst_path = "F:\\playground"
        [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]

        corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
                       'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
                       'lq_resampling4x']
        c_counter = 0
        test_set.corruptor.num_corrupts = 0
        test_set.corruptor.random_corruptions = []
        test_set.corruptor.fixed_corruptions = [corruptions[0]]
        corruption_mse = [(0, 0) for _ in corruptions]

        tq = tqdm(test_loader)
        batch_size = opt['datasets']['train']['batch_size']
        for data in tq:
            need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
            model.feed_data(data, need_GT=need_GT)
            model.test()
            est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1, 2, 3])
            for i in range(est_psnr.shape[0]):
                im_path = data['GT_path'][i]
                torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path)))
                #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))

            last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
            corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
            c_counter += 1
            test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
            if c_counter % 100 == 0:
                for i, (mse, ctr) in enumerate(corruption_mse):
                    print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))
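The loop above sorts LQ crops into the ten folders created earlier according to the estimated quality score; the folder layout implies the approximator's per-image mean lies roughly in [0, 1), which is an assumption here. For example:

est = 0.73                     # hypothetical per-image psnr_approximate mean
folder = str(int(est * 10))    # -> '7', so the crop lands in F:\playground\7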
@@ -1,136 +0,0 @@
import numpy as np
import torch
import torch.nn as nn

# note: all dct related functions are either exactly as or based on those
# at https://github.com/zh217/torch-dct
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.rfft(v, 1, onesided=False)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V

def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct(dct(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """

    x_shape = X.shape
    N = x_shape[-1]

    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_r = V_t_r.to(X.device)
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
    V_t_i = V_t_i.to(X.device)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    v = torch.irfft(V, 1, onesided=False)
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)

def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    return X2.transpose(-1, -2)

def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct_2d(dct_2d(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    return x2.transpose(-1, -2)

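A minimal round-trip check for the transforms above. Note that dct/idct rely on torch.rfft/torch.irfft, which were removed in PyTorch 1.8, so this sketch assumes an older PyTorch build:

x = torch.randn(4, 3, 32, 32)
X = dct_2d(x, norm='ortho')        # DCT-II over the last two dimensions
x_rec = idct_2d(X, norm='ortho')   # scaled DCT-III inverse
assert torch.allclose(x, x_rec, atol=1e-4)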
def extract_patches_2d(img, patch_shape, step=[1.0, 1.0], batch_first=False):
    """
    source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]
    if img.size(2) < patch_H:
        num_padded_H_Top = (patch_H - img.size(2)) // 2
        num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
        padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
        img = padding_H(img)
    if img.size(3) < patch_W:
        num_padded_W_Left = (patch_W - img.size(3)) // 2
        num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
        padding_W = nn.ConstantPad2d((num_padded_W_Left, num_padded_W_Right, 0, 0), 0)
        img = padding_W(img)
    step_int = [0, 0]
    step_int[0] = int(patch_H * step[0]) if isinstance(step[0], float) else step[0]
    step_int[1] = int(patch_W * step[1]) if isinstance(step[1], float) else step[1]
    patches_fold_H = img.unfold(2, patch_H, step_int[0])
    if (img.size(2) - patch_H) % step_int[0] != 0:
        patches_fold_H = torch.cat((patches_fold_H,
                                    img[:, :, -patch_H:, :].permute(0, 1, 3, 2).unsqueeze(2)), dim=2)
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
    if (img.size(3) - patch_W) % step_int[1] != 0:
        patches_fold_HW = torch.cat((patches_fold_HW,
                                     patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
    patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
    patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    return patches
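A small usage sketch for the patch extractor; with the default step of 1.0 the patches tile the image without overlap:

img = torch.randn(2, 3, 64, 64)                                     # (B, C, H, W)
patches = extract_patches_2d(img, (16, 16), step=[1.0, 1.0], batch_first=True)
print(patches.shape)                                                # torch.Size([2, 16, 3, 16, 16])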