Clean up codebase
Remove stuff that I'm likely not going to use again (or generally failed experiments)
This commit is contained in:
parent
4d1a42e944
commit
55b58fb67f
|
@ -1,507 +0,0 @@
|
|||
import math
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
from functools import wraps, partial
|
||||
from math import floor
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from kornia import augmentation as augs
|
||||
from kornia import filters, color
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# helper functions
|
||||
from trainer.networks import register_model, create_model
|
||||
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
def default(val, def_val):
|
||||
return def_val if val is None else val
|
||||
|
||||
def rand_true(prob):
|
||||
return random.random() < prob
|
||||
|
||||
def singleton(cache_key):
|
||||
def inner_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
instance = getattr(self, cache_key)
|
||||
if instance is not None:
|
||||
return instance
|
||||
|
||||
instance = fn(self, *args, **kwargs)
|
||||
setattr(self, cache_key, instance)
|
||||
return instance
|
||||
return wrapper
|
||||
return inner_fn
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def set_requires_grad(model, val):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = val
|
||||
|
||||
def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
|
||||
_, _, orig_h, orig_w = image.shape
|
||||
|
||||
ratio_lo, ratio_hi = ratio_range
|
||||
random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
|
||||
w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
|
||||
coor_x = floor((orig_w - w) * random.random())
|
||||
coor_y = floor((orig_h - h) * random.random())
|
||||
return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio
|
||||
|
||||
def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
|
||||
shape = image.shape
|
||||
output_size = default(output_size, shape[2:])
|
||||
(y0, y1), (x0, x1) = coordinates
|
||||
cutout_image = image[:, :, y0:y1, x0:x1]
|
||||
return F.interpolate(cutout_image, size = output_size, mode = mode)
|
||||
|
||||
def scale_coords(coords, scale):
|
||||
output = [[0,0],[0,0]]
|
||||
for j in range(2):
|
||||
for k in range(2):
|
||||
output[j][k] = int(coords[j][k] / scale)
|
||||
return output
|
||||
|
||||
def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
|
||||
blank = torch.zeros_like(image)
|
||||
coordinates = scale_coords(coordinates, scale_reduction)
|
||||
(y0, y1), (x0, x1) = coordinates
|
||||
orig_cutout_shape = (y1-y0, x1-x0)
|
||||
if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
|
||||
return None
|
||||
|
||||
un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
|
||||
blank[:,:,y0:y1,x0:x1] = un_resized_img
|
||||
return blank
|
||||
|
||||
def compute_shared_coords(coords1, coords2, scale_reduction):
|
||||
(y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
|
||||
(y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
|
||||
shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
|
||||
(max(x1_l, x2_l), min(x1_r, x2_r)))
|
||||
for s in shared:
|
||||
if s == 0:
|
||||
return None
|
||||
return shared
|
||||
|
||||
def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
|
||||
# Unflip the pixel projections
|
||||
proj_pixel_one = flip_image_one_fn(proj_pixel_one)
|
||||
proj_pixel_two = flip_image_two_fn(proj_pixel_two)
|
||||
|
||||
# Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
|
||||
scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
|
||||
proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
|
||||
mode=interp_mode)
|
||||
proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
|
||||
mode=interp_mode)
|
||||
if proj_pixel_one is None or proj_pixel_two is None:
|
||||
print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
|
||||
return None
|
||||
|
||||
# Compute the shared coordinates for the two cutouts:
|
||||
shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
|
||||
if shared_coords is None:
|
||||
print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..")
|
||||
return None
|
||||
(yt, yb), (xl, xr) = shared_coords
|
||||
|
||||
return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
|
||||
|
||||
# augmentation utils
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
def __init__(self, fn, p):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
return self.fn(x)
|
||||
|
||||
# exponential moving average
|
||||
|
||||
class EMA():
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
def update_moving_average(ema_updater, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = ema_updater.update_average(old_weight, up_weight)
|
||||
|
||||
# loss fn
|
||||
|
||||
def loss_fn(x, y):
|
||||
x = F.normalize(x, dim=-1, p=2)
|
||||
y = F.normalize(y, dim=-1, p=2)
|
||||
return 2 - 2 * (x * y).sum(dim=-1)
|
||||
|
||||
# classes
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, chan, chan_out = 256, inner_dim = 2048):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(chan, inner_dim),
|
||||
nn.BatchNorm1d(inner_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(inner_dim, chan_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class ConvMLP(nn.Module):
|
||||
def __init__(self, chan, chan_out = 256, inner_dim = 2048):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, inner_dim, 1),
|
||||
nn.BatchNorm2d(inner_dim),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(inner_dim, chan_out, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class PPM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
chan,
|
||||
num_layers = 1,
|
||||
gamma = 2):
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
|
||||
if num_layers == 0:
|
||||
self.transform_net = nn.Identity()
|
||||
elif num_layers == 1:
|
||||
self.transform_net = nn.Conv2d(chan, chan, 1)
|
||||
elif num_layers == 2:
|
||||
self.transform_net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan, 1),
|
||||
nn.BatchNorm2d(chan),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
else:
|
||||
raise ValueError('num_layers must be one of 0, 1, or 2')
|
||||
|
||||
def forward(self, x):
|
||||
xi = x[:, :, :, :, None, None]
|
||||
xj = x[:, :, None, None, :, :]
|
||||
similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma
|
||||
|
||||
transform_out = self.transform_net(x)
|
||||
out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
|
||||
return out
|
||||
|
||||
# a wrapper class for the base neural network
|
||||
# will manage the interception of the hidden layer output
|
||||
# and pipe it into the projecter and predictor nets
|
||||
|
||||
class NetWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
net,
|
||||
instance_projection_size,
|
||||
instance_projection_hidden_size,
|
||||
pix_projection_size,
|
||||
pix_projection_hidden_size,
|
||||
layer_pixel = -2,
|
||||
layer_instance = -2
|
||||
):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer_pixel = layer_pixel
|
||||
self.layer_instance = layer_instance
|
||||
|
||||
self.pixel_projector = None
|
||||
self.instance_projector = None
|
||||
|
||||
self.instance_projection_size = instance_projection_size
|
||||
self.instance_projection_hidden_size = instance_projection_hidden_size
|
||||
self.pix_projection_size = pix_projection_size
|
||||
self.pix_projection_hidden_size = pix_projection_hidden_size
|
||||
|
||||
self.hidden_pixel = None
|
||||
self.hidden_instance = None
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self, layer_id):
|
||||
if type(layer_id) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(layer_id, None)
|
||||
elif type(layer_id) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[layer_id]
|
||||
return None
|
||||
|
||||
def _hook(self, attr_name, _, __, output):
|
||||
setattr(self, attr_name, output)
|
||||
|
||||
def _register_hook(self):
|
||||
pixel_layer = self._find_layer(self.layer_pixel)
|
||||
instance_layer = self._find_layer(self.layer_instance)
|
||||
|
||||
assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
|
||||
assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'
|
||||
|
||||
pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel'))
|
||||
instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance'))
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('pixel_projector')
|
||||
def _get_pixel_projector(self, hidden):
|
||||
_, dim, *_ = hidden.shape
|
||||
projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
@singleton('instance_projector')
|
||||
def _get_instance_projector(self, hidden):
|
||||
_, dim = hidden.shape
|
||||
projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_representation(self, x):
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
_ = self.net(x)
|
||||
hidden_pixel = self.hidden_pixel
|
||||
hidden_instance = self.hidden_instance
|
||||
self.hidden_pixel = None
|
||||
self.hidden_instance = None
|
||||
assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
|
||||
assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
|
||||
return hidden_pixel, hidden_instance
|
||||
|
||||
def forward(self, x):
|
||||
pixel_representation, instance_representation = self.get_representation(x)
|
||||
instance_representation = instance_representation.flatten(1)
|
||||
|
||||
pixel_projector = self._get_pixel_projector(pixel_representation)
|
||||
instance_projector = self._get_instance_projector(instance_representation)
|
||||
|
||||
pixel_projection = pixel_projector(pixel_representation)
|
||||
instance_projection = instance_projector(instance_representation)
|
||||
return pixel_projection, instance_projection
|
||||
|
||||
# main class
|
||||
class PixelCL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer_pixel = -2,
|
||||
hidden_layer_instance = -2,
|
||||
instance_projection_size = 256,
|
||||
instance_projection_hidden_size = 2048,
|
||||
pix_projection_size = 256,
|
||||
pix_projection_hidden_size = 2048,
|
||||
augment_fn = None,
|
||||
augment_fn2 = None,
|
||||
prob_rand_hflip = 0.25,
|
||||
moving_average_decay = 0.99,
|
||||
ppm_num_layers = 1,
|
||||
ppm_gamma = 2,
|
||||
distance_thres = 0.7,
|
||||
similarity_temperature = 0.3,
|
||||
cutout_ratio_range = (0.6, 0.8),
|
||||
cutout_interpolate_mode = 'nearest',
|
||||
coord_cutout_interpolate_mode = 'bilinear',
|
||||
max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
DEFAULT_AUG = nn.Sequential(
|
||||
RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomSolarize(p=0.5),
|
||||
# Normalize left out because it should be done at the model level.
|
||||
)
|
||||
|
||||
self.augment1 = default(augment_fn, DEFAULT_AUG)
|
||||
self.augment2 = default(augment_fn2, self.augment1)
|
||||
self.prob_rand_hflip = prob_rand_hflip
|
||||
|
||||
self.online_encoder = NetWrapper(
|
||||
net = net,
|
||||
instance_projection_size = instance_projection_size,
|
||||
instance_projection_hidden_size = instance_projection_hidden_size,
|
||||
pix_projection_size = pix_projection_size,
|
||||
pix_projection_hidden_size = pix_projection_hidden_size,
|
||||
layer_pixel = hidden_layer_pixel,
|
||||
layer_instance = hidden_layer_instance
|
||||
)
|
||||
|
||||
self.target_encoder = None
|
||||
self.target_ema_updater = EMA(moving_average_decay)
|
||||
|
||||
self.distance_thres = distance_thres
|
||||
self.similarity_temperature = similarity_temperature
|
||||
|
||||
# This requirement is due to the way that these are processed, not a hard requirement.
|
||||
assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
|
||||
self.max_latent_dim = max_latent_dim
|
||||
|
||||
self.propagate_pixels = PPM(
|
||||
chan = pix_projection_size,
|
||||
num_layers = ppm_num_layers,
|
||||
gamma = ppm_gamma
|
||||
)
|
||||
|
||||
self.cutout_ratio_range = cutout_ratio_range
|
||||
self.cutout_interpolate_mode = cutout_interpolate_mode
|
||||
self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode
|
||||
|
||||
# instance level predictor
|
||||
self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
@singleton('target_encoder')
|
||||
def _get_target_encoder(self):
|
||||
target_encoder = copy.deepcopy(self.online_encoder)
|
||||
set_requires_grad(target_encoder, False)
|
||||
return target_encoder
|
||||
|
||||
def reset_moving_average(self):
|
||||
del self.target_encoder
|
||||
self.target_encoder = None
|
||||
|
||||
def update_moving_average(self):
|
||||
assert self.target_encoder is not None, 'target encoder has not been created yet'
|
||||
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
||||
|
||||
def forward(self, x):
|
||||
shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip
|
||||
|
||||
rand_flip_fn = lambda t: torch.flip(t, dims = (-1,))
|
||||
|
||||
flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
|
||||
flip_image_one_fn = rand_flip_fn if flip_image_one else identity
|
||||
flip_image_two_fn = rand_flip_fn if flip_image_two else identity
|
||||
|
||||
cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
|
||||
cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)
|
||||
|
||||
image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode)
|
||||
image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode)
|
||||
|
||||
image_one_cutout = flip_image_one_fn(image_one_cutout)
|
||||
image_two_cutout = flip_image_two_fn(image_two_cutout)
|
||||
|
||||
image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)
|
||||
|
||||
self.aug1 = image_one_cutout.detach().clone()
|
||||
self.aug2 = image_two_cutout.detach().clone()
|
||||
|
||||
proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
|
||||
proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)
|
||||
|
||||
proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
|
||||
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
|
||||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
if proj_pixel_one is None or proj_pixel_two is None:
|
||||
positive_pixel_pairs = 0
|
||||
else:
|
||||
positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]
|
||||
|
||||
with torch.no_grad():
|
||||
target_encoder = self._get_target_encoder()
|
||||
target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
|
||||
target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
|
||||
target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
|
||||
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
|
||||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
|
||||
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
|
||||
b, c, pp_h, pp_w = proj_pixel_one.shape
|
||||
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
|
||||
prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
|
||||
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
|
||||
extracted = []
|
||||
for l in latents:
|
||||
l = l.reshape(b, c, pp_h * pp_w)
|
||||
l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
|
||||
# For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
|
||||
# Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves and others without regards
|
||||
# to the original image structure.
|
||||
sqdim = int(math.sqrt(self.max_latent_dim))
|
||||
extracted.append(l.reshape(b, c, sqdim, sqdim))
|
||||
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
|
||||
|
||||
# flatten all the pixel projections
|
||||
flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
|
||||
target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))
|
||||
|
||||
# get instance level loss
|
||||
pred_instance_one = self.online_predictor(proj_instance_one)
|
||||
pred_instance_two = self.online_predictor(proj_instance_two)
|
||||
loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
|
||||
loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
|
||||
instance_loss = (loss_instance_one + loss_instance_two).mean()
|
||||
|
||||
if positive_pixel_pairs == 0:
|
||||
return instance_loss, 0
|
||||
|
||||
# calculate pix pro loss
|
||||
propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
|
||||
propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
|
||||
|
||||
propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
|
||||
|
||||
propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
|
||||
propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
|
||||
|
||||
loss_pixpro_one_two = - propagated_similarity_one_two.mean()
|
||||
loss_pixpro_two_one = - propagated_similarity_two_one.mean()
|
||||
|
||||
pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
|
||||
|
||||
return instance_loss, pix_loss, positive_pixel_pairs
|
||||
|
||||
# Allows visualizing what the augmentor is up to.
|
||||
def visual_dbg(self, step, path):
|
||||
if not hasattr(self, 'aug1'):
|
||||
return
|
||||
torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
|
||||
|
||||
|
||||
@register_model
|
||||
def register_pixel_contrastive_learner(opt_net, opt):
|
||||
subnet = create_model(opt, opt_net['subnet'])
|
||||
kwargs = opt_net['kwargs']
|
||||
if 'subnet_pretrain_path' in opt_net.keys():
|
||||
sd = torch.load(opt_net['subnet_pretrain_path'])
|
||||
subnet.load_state_dict(sd, strict=False)
|
||||
return PixelCL(subnet, **kwargs)
|
|
@ -1,152 +0,0 @@
|
|||
# Resnet implementation that adds a u-net style up-conversion component to output values at a
|
||||
# specified pixel density.
|
||||
#
|
||||
# The downsampling part of the network is compatible with the built-in torch resnet for use in
|
||||
# transfer learning.
|
||||
#
|
||||
# Only resnet50 currently supported.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import torchvision
|
||||
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class ReverseBottleneck(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, groups=1, passthrough=False,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
self.passthrough = passthrough
|
||||
if passthrough:
|
||||
self.integrate = conv1x1(inplanes*2, inplanes)
|
||||
self.bn_integrate = norm_layer(inplanes)
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.residual_upsample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
conv1x1(width, width),
|
||||
norm_layer(width),
|
||||
)
|
||||
self.conv3 = conv1x1(width, planes)
|
||||
self.bn3 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.upsample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
conv1x1(inplanes, planes),
|
||||
norm_layer(planes),
|
||||
)
|
||||
|
||||
def forward(self, x, passthrough=None):
|
||||
if self.passthrough:
|
||||
x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.residual_upsample(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
identity = self.upsample(x)
|
||||
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UResNet50(torchvision.models.resnet.ResNet):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, out_dim=128):
|
||||
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
|
||||
replace_stride_with_dilation, norm_layer)
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
'''
|
||||
# For reference:
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
'''
|
||||
uplayers = []
|
||||
inplanes = 2048
|
||||
first = True
|
||||
for i in range(2):
|
||||
uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
|
||||
inplanes = inplanes // 2
|
||||
first = False
|
||||
self.uplayers = nn.ModuleList(uplayers)
|
||||
self.tail = nn.Sequential(conv1x1(1024, 512),
|
||||
norm_layer(512),
|
||||
nn.ReLU(),
|
||||
conv3x3(512, 512),
|
||||
norm_layer(512),
|
||||
nn.ReLU(),
|
||||
conv1x1(512, out_dim))
|
||||
|
||||
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
|
||||
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
|
||||
# except using checkpoints on the body conv layers.
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x1 = checkpoint(self.layer1, x)
|
||||
x2 = checkpoint(self.layer2, x1)
|
||||
x3 = checkpoint(self.layer3, x2)
|
||||
x4 = checkpoint(self.layer4, x3)
|
||||
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
|
||||
|
||||
x = checkpoint(self.uplayers[0], x4)
|
||||
x = checkpoint(self.uplayers[1], x, x3)
|
||||
#x = checkpoint(self.uplayers[2], x, x2)
|
||||
#x = checkpoint(self.uplayers[3], x, x1)
|
||||
|
||||
return checkpoint(self.tail, torch.cat([x, x2], dim=1))
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_u_resnet50(opt_net, opt):
|
||||
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
||||
if opt_get(opt_net, ['use_pretrained_base'], False):
|
||||
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = UResNet50(Bottleneck, [3,4,6,3])
|
||||
samp = torch.rand(1,3,224,224)
|
||||
model(samp)
|
||||
# For pixpro: attach to "tail.3"
|
|
@ -1,87 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import torchvision
|
||||
|
||||
from models.arch_util import ConvBnRelu
|
||||
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class UResNet50_2(torchvision.models.resnet.ResNet):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, out_dim=128):
|
||||
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
|
||||
replace_stride_with_dilation, norm_layer)
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self.level_conv = ConvBnRelu(3, 64)
|
||||
'''
|
||||
# For reference:
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
'''
|
||||
uplayers = []
|
||||
inplanes = 2048
|
||||
first = True
|
||||
div = [2,2,2,4,1]
|
||||
for i in range(5):
|
||||
uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
|
||||
inplanes = inplanes // div[i]
|
||||
first = False
|
||||
self.uplayers = nn.ModuleList(uplayers)
|
||||
self.tail = nn.Sequential(conv3x3(128, 64),
|
||||
norm_layer(64),
|
||||
nn.ReLU(),
|
||||
conv1x1(64, out_dim))
|
||||
|
||||
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
|
||||
|
||||
|
||||
def _forward_impl(self, x):
|
||||
level = self.level_conv(x)
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x = self.maxpool(x0)
|
||||
|
||||
x1 = checkpoint(self.layer1, x)
|
||||
x2 = checkpoint(self.layer2, x1)
|
||||
x3 = checkpoint(self.layer3, x2)
|
||||
x4 = checkpoint(self.layer4, x3)
|
||||
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
|
||||
|
||||
x = checkpoint(self.uplayers[0], x4)
|
||||
x = checkpoint(self.uplayers[1], x, x3)
|
||||
x = checkpoint(self.uplayers[2], x, x2)
|
||||
x = checkpoint(self.uplayers[3], x, x1)
|
||||
x = checkpoint(self.uplayers[4], x, x0)
|
||||
|
||||
return checkpoint(self.tail, torch.cat([x, level], dim=1))
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_u_resnet50_2(opt_net, opt):
|
||||
model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
||||
if opt_get(opt_net, ['use_pretrained_base'], False):
|
||||
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = UResNet50_2(Bottleneck, [3,4,6,3])
|
||||
samp = torch.rand(1,3,224,224)
|
||||
y = model(samp)
|
||||
print(y.shape)
|
||||
# For pixpro: attach to "tail.3"
|
|
@ -1,86 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import torchvision
|
||||
|
||||
from models.arch_util import ConvBnRelu
|
||||
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class UResNet50_3(torchvision.models.resnet.ResNet):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, out_dim=128):
|
||||
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
|
||||
replace_stride_with_dilation, norm_layer)
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
'''
|
||||
# For reference:
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
'''
|
||||
uplayers = []
|
||||
inplanes = 2048
|
||||
first = True
|
||||
for i in range(3):
|
||||
uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
|
||||
inplanes = inplanes // 2
|
||||
first = False
|
||||
self.uplayers = nn.ModuleList(uplayers)
|
||||
|
||||
# These two variables are separated out and renamed so that I can re-use parameters from a pretrained resnet_unet2.
|
||||
self.last_uplayer = ReverseBottleneck(256, 128, norm_layer=norm_layer, passthrough=True)
|
||||
self.tail3 = nn.Sequential(conv1x1(192, 128),
|
||||
norm_layer(128),
|
||||
nn.ReLU(),
|
||||
conv1x1(128, out_dim))
|
||||
|
||||
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
|
||||
|
||||
|
||||
def _forward_impl(self, x):
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x = self.maxpool(x0)
|
||||
|
||||
x1 = checkpoint(self.layer1, x)
|
||||
x2 = checkpoint(self.layer2, x1)
|
||||
x3 = checkpoint(self.layer3, x2)
|
||||
x4 = checkpoint(self.layer4, x3)
|
||||
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
|
||||
|
||||
x = checkpoint(self.uplayers[0], x4)
|
||||
x = checkpoint(self.uplayers[1], x, x3)
|
||||
x = checkpoint(self.uplayers[2], x, x2)
|
||||
x = checkpoint(self.last_uplayer, x, x1)
|
||||
|
||||
return checkpoint(self.tail3, torch.cat([x, x0], dim=1))
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_u_resnet50_3(opt_net, opt):
|
||||
model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
|
||||
if opt_get(opt_net, ['use_pretrained_base'], False):
|
||||
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = UResNet50_3(Bottleneck, [3,4,6,3])
|
||||
samp = torch.rand(1,3,224,224)
|
||||
y = model(samp)
|
||||
print(y.shape)
|
||||
# For pixpro: attach to "tail.3"
|
|
@ -1,9 +0,0 @@
|
|||
from models.styled_sr.discriminator import StyleSrGanDivergenceLoss
|
||||
|
||||
|
||||
def create_stylesr_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
if type == 'style_sr_gan_divergence_loss':
|
||||
return StyleSrGanDivergenceLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
|
@ -1,344 +0,0 @@
|
|||
# Heavily based on the lucidrains stylegan2 discriminator implementation.
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
from math import log2
|
||||
from random import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.autograd import grad as torch_grad
|
||||
import trainer.losses as L
|
||||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
|
||||
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
|
||||
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
|
||||
leaky_relu()
|
||||
)
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
Blur(),
|
||||
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
|
||||
) if downsample else None
|
||||
|
||||
def forward(self, x):
|
||||
res = self.conv_res(x)
|
||||
x = self.net(x)
|
||||
if exists(self.downsample):
|
||||
x = self.downsample(x)
|
||||
x = (x + res) * (1 / math.sqrt(2))
|
||||
return x
|
||||
|
||||
|
||||
class StyleSrDiscriminator(nn.Module):
|
||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
|
||||
transfer_mode=False):
|
||||
super().__init__()
|
||||
num_layers = int(log2(image_size) - 1)
|
||||
|
||||
blocks = []
|
||||
filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
|
||||
|
||||
set_fmap_max = partial(min, fmap_max)
|
||||
filters = list(map(set_fmap_max, filters))
|
||||
chan_in_out = list(zip(filters[:-1], filters[1:]))
|
||||
|
||||
blocks = []
|
||||
attn_blocks = []
|
||||
quantize_blocks = []
|
||||
|
||||
for ind, (in_chan, out_chan) in enumerate(chan_in_out):
|
||||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(chan_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
|
||||
blocks.append(block)
|
||||
|
||||
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
|
||||
|
||||
attn_blocks.append(attn_fn)
|
||||
|
||||
if quantize:
|
||||
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
|
||||
quantize_blocks.append(quantize_fn)
|
||||
else:
|
||||
quantize_blocks.append(None)
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.attn_blocks = nn.ModuleList(attn_blocks)
|
||||
self.quantize_blocks = nn.ModuleList(quantize_blocks)
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
chan_last = filters[-1]
|
||||
latent_dim = 2 * 2 * chan_last
|
||||
|
||||
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
|
||||
self.flatten = nn.Flatten()
|
||||
if mlp:
|
||||
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
|
||||
leaky_relu(),
|
||||
TransferLinear(100, 1, transfer_mode=transfer_mode))
|
||||
else:
|
||||
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
self.transfer_mode = transfer_mode
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
def forward(self, x):
|
||||
b, *_ = x.shape
|
||||
|
||||
quantize_loss = torch.zeros(1).to(x)
|
||||
|
||||
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
|
||||
if self.do_checkpointing:
|
||||
x = checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
if exists(attn_block):
|
||||
x = attn_block(x)
|
||||
|
||||
if exists(q_block):
|
||||
x, _, loss = q_block(x)
|
||||
quantize_loss += loss
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = self.flatten(x)
|
||||
x = self.to_logit(x)
|
||||
if exists(q_block):
|
||||
return x.squeeze(), quantize_loss
|
||||
else:
|
||||
return x.squeeze()
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
# Configures the network as partially pre-trained. This means:
|
||||
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
|
||||
# 2) The head (linear layers) will also have their weights re-initialized
|
||||
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
|
||||
# These settings will be applied after the weights have been loaded (network_loaded())
|
||||
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
|
||||
self.bypass_blocks = bypass_blocks
|
||||
self.num_blocks = num_blocks
|
||||
self.frozen_until_step = frozen_until_step
|
||||
|
||||
# Called after the network weights are loaded.
|
||||
def network_loaded(self):
|
||||
if not hasattr(self, 'frozen_until_step'):
|
||||
return
|
||||
|
||||
if self.bypass_blocks > 0:
|
||||
self.blocks = self.blocks[self.bypass_blocks:]
|
||||
self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
|
||||
|
||||
reset_blocks = [self.to_logit]
|
||||
for i in range(self.num_blocks):
|
||||
reset_blocks.append(self.blocks[i])
|
||||
for bl in reset_blocks:
|
||||
for m in bl.modules():
|
||||
if type(m) in {TransferConv2d, TransferLinear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
for p in m.parameters(recurse=True):
|
||||
p._NEW_BLOCK = True
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, '_NEW_BLOCK'):
|
||||
p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
|
||||
|
||||
|
||||
# helper classes
|
||||
def DiffAugment(x, types=[]):
|
||||
for p in types:
|
||||
for f in AUGMENT_FNS[p]:
|
||||
x = f(x)
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def random_hflip(tensor, prob):
|
||||
if prob > random():
|
||||
return tensor
|
||||
return torch.flip(tensor, dims=(3,))
|
||||
|
||||
|
||||
def rand_brightness(x):
|
||||
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
|
||||
return x
|
||||
|
||||
|
||||
def rand_saturation(x):
|
||||
x_mean = x.mean(dim=1, keepdim=True)
|
||||
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
|
||||
return x
|
||||
|
||||
|
||||
def rand_contrast(x):
|
||||
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
|
||||
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
|
||||
return x
|
||||
|
||||
|
||||
def rand_translation(x, ratio=0.125):
|
||||
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
||||
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
|
||||
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
|
||||
grid_batch, grid_x, grid_y = torch.meshgrid(
|
||||
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
||||
torch.arange(x.size(2), dtype=torch.long, device=x.device),
|
||||
torch.arange(x.size(3), dtype=torch.long, device=x.device),
|
||||
)
|
||||
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
|
||||
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
|
||||
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
|
||||
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
def rand_cutout(x, ratio=0.5):
|
||||
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
||||
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
|
||||
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
|
||||
grid_batch, grid_x, grid_y = torch.meshgrid(
|
||||
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
||||
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
|
||||
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
|
||||
)
|
||||
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
|
||||
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
|
||||
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
|
||||
mask[grid_batch, grid_x, grid_y] = 0
|
||||
x = x * mask.unsqueeze(1)
|
||||
return x
|
||||
|
||||
|
||||
AUGMENT_FNS = {
|
||||
'color': [rand_brightness, rand_saturation, rand_contrast],
|
||||
'translation': [rand_translation],
|
||||
'cutout': [rand_cutout],
|
||||
}
|
||||
|
||||
|
||||
class DiscAugmentor(nn.Module):
|
||||
def __init__(self, D, image_size, types, prob):
|
||||
super().__init__()
|
||||
self.D = D
|
||||
self.prob = prob
|
||||
self.types = types
|
||||
|
||||
def forward(self, images, real_images=False):
|
||||
if random() < self.prob:
|
||||
images = random_hflip(images, prob=0.5)
|
||||
images = DiffAugment(images, types=self.types)
|
||||
|
||||
if real_images:
|
||||
self.hq_aug = images.detach().clone()
|
||||
else:
|
||||
self.gen_aug = images.detach().clone()
|
||||
|
||||
# Save away for use elsewhere (e.g. unet loss)
|
||||
self.aug_images = images
|
||||
|
||||
return self.D(images)
|
||||
|
||||
def network_loaded(self):
|
||||
self.D.network_loaded()
|
||||
|
||||
# Allows visualizing what the augmentor is up to.
|
||||
def visual_dbg(self, step, path):
|
||||
torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
|
||||
torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
|
||||
|
||||
|
||||
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
|
||||
if fp16:
|
||||
with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
|
||||
scaled_loss.backward(**kwargs)
|
||||
else:
|
||||
loss.backward(**kwargs)
|
||||
|
||||
|
||||
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
|
||||
batch_size = images.shape[0]
|
||||
gradients = torch_grad(outputs=output, inputs=images,
|
||||
grad_outputs=torch.ones(output.size(), device=images.device),
|
||||
create_graph=True, retain_graph=True, only_inputs=True)[0]
|
||||
|
||||
flat_grad = gradients.reshape(batch_size, -1)
|
||||
penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
|
||||
if return_structured_grads:
|
||||
return penalty, gradients
|
||||
else:
|
||||
return penalty
|
||||
|
||||
|
||||
class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.real = opt['real']
|
||||
self.fake = opt['fake']
|
||||
self.discriminator = opt['discriminator']
|
||||
self.for_gen = opt['gen_loss']
|
||||
self.gp_frequency = opt['gradient_penalty_frequency']
|
||||
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||
|
||||
def forward(self, net, state):
|
||||
real_input = state[self.real]
|
||||
fake_input = state[self.fake]
|
||||
if self.noise != 0:
|
||||
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
||||
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||
|
||||
D = self.env['discriminators'][self.discriminator]
|
||||
fake = D(fake_input, real_images=False)
|
||||
if self.for_gen:
|
||||
return fake.mean()
|
||||
else:
|
||||
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||
real = D(real_input, real_images=True)
|
||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
||||
real_input.requires_grad_(requires_grad=False)
|
||||
return divergence_loss
|
||||
|
||||
|
||||
@register_model
|
||||
def register_styledsr_discriminator(opt_net, opt):
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
||||
quantize=opt_get(opt_net, ['quantize'], False),
|
||||
mlp=opt_get(opt_net, ['mlp_head'], True),
|
||||
transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
|
||||
)
|
||||
if 'use_partial_pretrained' in opt_net.keys():
|
||||
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
|
||||
return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
|
@ -1,199 +0,0 @@
|
|||
from random import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.arch_util import kaiming_init
|
||||
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
|
||||
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
def rrdb_init_weights(module, scale=1):
|
||||
for m in module.modules():
|
||||
if isinstance(m, TransferConv2d):
|
||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
||||
m.weight.data *= scale
|
||||
elif isinstance(m, TransferLinear):
|
||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
||||
m.weight.data *= scale
|
||||
|
||||
|
||||
class EncoderRRDB(nn.Module):
|
||||
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
|
||||
super(EncoderRRDB, self).__init__()
|
||||
for i in range(5):
|
||||
out_channels = output_channels if i == 4 else growth_channels
|
||||
self.add_module(
|
||||
f'conv{i+1}',
|
||||
TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
for i in range(5):
|
||||
rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5
|
||||
|
||||
|
||||
class StyledSrEncoder(nn.Module):
|
||||
def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
|
||||
super().__init__()
|
||||
# Current assumes fea_out=256.
|
||||
self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
|
||||
self.rrdbs = nn.ModuleList([
|
||||
EncoderRRDB(32, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(64, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(96, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(128, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(160, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(192, transfer_mode=transfer_mode),
|
||||
EncoderRRDB(224, transfer_mode=transfer_mode)])
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.initial_conv(x)
|
||||
for rrdb in self.rrdbs:
|
||||
fea = torch.cat([fea, checkpoint(rrdb, fea)], dim=1)
|
||||
return fea
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
|
||||
super().__init__()
|
||||
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
|
||||
self.image_size = image_size
|
||||
self.scale = 2 ** upsample_levels
|
||||
self.latent_dim = latent_dim
|
||||
self.num_layers = total_levels
|
||||
self.transfer_mode = transfer_mode
|
||||
filters = [
|
||||
512, # 4x4
|
||||
512, # 8x8
|
||||
512, # 16x16
|
||||
256, # 32x32
|
||||
128, # 64x64
|
||||
64, # 128x128
|
||||
32, # 256x256
|
||||
16, # 512x512
|
||||
8, # 1024x1024
|
||||
]
|
||||
|
||||
# I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted.
|
||||
self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)
|
||||
|
||||
in_out_pairs = list(zip(filters[:-1], filters[1:]))
|
||||
self.blocks = nn.ModuleList([])
|
||||
for ind in range(start_level, start_level+total_levels):
|
||||
in_chan, out_chan = in_out_pairs[ind]
|
||||
not_first = ind != start_level
|
||||
not_last = ind != (start_level+total_levels-1)
|
||||
block = GeneratorBlock(
|
||||
latent_dim,
|
||||
in_chan,
|
||||
out_chan,
|
||||
upsample=not_first,
|
||||
upsample_rgb=not_last,
|
||||
transfer_learning_mode=transfer_mode
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
def forward(self, lr, styles):
|
||||
b, c, h, w = lr.shape
|
||||
if self.transfer_mode:
|
||||
with torch.no_grad():
|
||||
x = self.encoder(lr)
|
||||
else:
|
||||
x = self.encoder(lr)
|
||||
|
||||
styles = styles.transpose(0, 1)
|
||||
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
||||
if h != x.shape[-2]:
|
||||
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
|
||||
else:
|
||||
rgb = lr
|
||||
|
||||
for style, block in zip(styles, self.blocks):
|
||||
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class StyledSrGenerator(nn.Module):
|
||||
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
|
||||
super().__init__()
|
||||
# Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
|
||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
|
||||
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
|
||||
self.l2 = nn.MSELoss()
|
||||
self.mixed_prob = .9
|
||||
self._init_weights()
|
||||
self.transfer_mode = transfer_mode
|
||||
self.initial_stride = initial_stride
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
for block in self.gen.blocks:
|
||||
nn.init.zeros_(block.to_noise1.weight)
|
||||
nn.init.zeros_(block.to_noise2.weight)
|
||||
nn.init.zeros_(block.to_noise1.bias)
|
||||
nn.init.zeros_(block.to_noise2.bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, f, h, w = x.shape
|
||||
|
||||
# Synthesize style latents from noise.
|
||||
style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
|
||||
if self.transfer_mode:
|
||||
with torch.no_grad():
|
||||
w = self.vectorizer(style)
|
||||
else:
|
||||
w = self.vectorizer(style)
|
||||
|
||||
# Randomly distribute styles across layers
|
||||
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
|
||||
for j in range(b):
|
||||
cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
|
||||
if cutoff == self.gen.num_layers or random() > self.mixed_prob:
|
||||
w_styles[j] = w_styles[j*2]
|
||||
else:
|
||||
w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
|
||||
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
|
||||
w_styles = w_styles[:b]
|
||||
|
||||
out = self.gen(x, w_styles)
|
||||
|
||||
# Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
|
||||
# for regularization.
|
||||
out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
|
||||
if self.initial_stride > 1:
|
||||
x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
|
||||
l2_reg = self.l2(x, out_down)
|
||||
|
||||
return out, l2_reg, w_styles
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gen = StyledSrGenerator(128, 2)
|
||||
out = gen(torch.rand(1,3,64,64))
|
||||
print([o.shape for o in out])
|
||||
|
||||
|
||||
@register_model
|
||||
def register_styled_sr(opt_net, opt):
|
||||
return StyledSrGenerator(128,
|
||||
initial_stride=opt_get(opt_net, ['initial_stride'], 1),
|
||||
transfer_mode=opt_get(opt_net, ['transfer_mode'], False))
|
@@ -1,411 +0,0 @@
import math
import multiprocessing
from contextlib import contextmanager, ExitStack

import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn, Tensor
from torch.autograd import grad as torch_grad
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd

from models.styled_sr.transfer_primitives import TransferLinear

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

num_cores = multiprocessing.cpu_count()

# constants
EPS = 1e-8


class NanException(Exception):
    pass


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class PermuteToFrom(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)


# one layer of self-attention and feedforward, for images

attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])


# helpers

def exists(val):
    return val is not None


@contextmanager
def null_context():
    yield


def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts


def default(value, d):
    return value if exists(value) else d


def cycle(iterable):
    while True:
        for i in iterable:
            yield i


def cast_list(el):
    return el if isinstance(el, list) else [el]


def is_empty(t):
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return not exists(t)


def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException


def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield
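

# Editor's note: a minimal usage sketch (not part of the original file) showing how the
# generator above is typically driven from a training loop. `model_ddp`, `optimizer` and
# `next_batch` are hypothetical stand-ins; the helper yields once per accumulation step,
# wrapping all but the last step in DDP's no_sync() context when is_ddp=True.
def _example_gradient_accumulation(model_ddp, optimizer, next_batch, accumulate_every=4, is_ddp=True):
    optimizer.zero_grad()
    for _ in gradient_accumulate_contexts(accumulate_every, is_ddp, ddps=[model_ddp]):
        loss = model_ddp(next_batch()).mean()
        (loss / accumulate_every).backward()
    optimizer.step()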


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def calc_pl_lengths(styles, images):
    device = images.device
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
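

# Editor's note: illustrative sketch (not in the original file) of how the path-length
# statistic above would feed a StyleGAN2-style regularizer. `w_styles` is the
# (batch, num_layers, latent_dim) style tensor, `images` the generator output, and
# `pl_mean` a hypothetical running-mean buffer kept by the trainer.
def _example_path_length_penalty(w_styles, images, pl_mean):
    pl_lengths = calc_pl_lengths(w_styles, images)      # one length per batch element
    pl_penalty = ((pl_lengths - pl_mean) ** 2).mean()   # pull lengths toward the running mean
    return pl_penalty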


def image_noise(n, im_size, device):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool


def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
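

# Editor's note: small usage sketch (editorial addition). slerp() expects row vectors and
# interpolates along the unit sphere, the usual way to walk between two latent codes.
def _example_slerp_walk(z0, z1, steps=8):
    # z0, z1: tensors of shape (1, latent_dim)
    return [slerp(t, z0, z1) for t in torch.linspace(0., 1., steps).tolist()]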


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul

        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def forward(self, input):
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)


class StyleVectorizer(nn.Module):
    def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = F.normalize(x, dim=1)
        return self.net(x)


class RGBBlock(nn.Module):
    def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            Blur()
        ) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if exists(prev_rgb):
            x = x + prev_rgb

        if exists(self.upsample):
            x = self.upsample(x)

        return x


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
        from models.archs.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        gamma = self.style2scale(style)
        beta = self.style2bias(style)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)
    return module
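

# Editor's note: usage sketch (editorial addition). equal_lr() re-registers the trainable
# parameter as `weight_orig` and recomputes a fan-in-scaled `weight` on every forward,
# the equalized learning-rate trick from PGGAN/StyleGAN.
def _example_equal_lr():
    layer = equal_lr(nn.Linear(512, 512))
    # 'weight_orig' is now the parameter; 'weight' is rebuilt by the forward pre-hook.
    return layer(torch.randn(4, 512))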


class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight

        w1 = y[:, None, :, None, None]
        w2 = weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x
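

# Editor's note: shape sketch (editorial addition) for the modulated convolution above.
# Each batch element gets its own kernel, scaled by a per-input-channel style vector and
# optionally demodulated, using a single grouped convolution.
def _example_conv2dmod():
    conv = Conv2DMod(in_chan=64, out_chan=128, kernel=3)
    feats = torch.randn(2, 64, 16, 16)     # (batch, in_chan, h, w)
    style = torch.randn(2, 64)             # one scale per input channel, per batch element
    return conv(feats, style)              # -> (2, 128, 16, 16)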


class GeneratorBlock(nn.Module):
    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
                 transfer_learning_mode=False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
        self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
        self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)

        self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
        self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
        self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)

        self.transfer_learning_mode = transfer_learning_mode

    def forward(self, x, prev_rgb, istyle, inoise):
        if exists(self.upsample):
            x = self.upsample(x)

        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb
@@ -1,136 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _ntuple

_pair = _ntuple(2)


class TransferConv2d(_ConvNd):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride = 1,
            padding = 0,
            dilation = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            transfer_mode: bool = False
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _conv_forward(self, input, weight):
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift
        else:
            weight = weight

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight)


class TransferLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
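

# Editor's note: sketch (editorial addition) of one plausible way the FOR_TRANSFER_LEARNING
# flag above is consumed: when fine-tuning, only the per-weight scale/shift parameters are
# handed to the optimizer while the pretrained weights stay frozen. The optimizer choice
# and learning rate here are hypothetical.
def _example_transfer_params(model):
    transfer_params = [p for p in model.parameters() if hasattr(p, 'FOR_TRANSFER_LEARNING')]
    for p in model.parameters():
        p.requires_grad = hasattr(p, 'FOR_TRANSFER_LEARNING')
    return torch.optim.Adam(transfer_params, lr=1e-4)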


class TransferConvGnLelu(nn.Module):
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
        super().__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
            self.gn = None
        if activation:
            self.lelu = nn.LeakyReLU(negative_slope=.2)
        else:
            self.lelu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, TransferConv2d):
                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.gn:
            x = self.gn(x)
        if self.lelu:
            return self.lelu(x)
        else:
            return x
@@ -1,15 +0,0 @@
import munch
import torch

from trainer.networks import register_model


@register_model
def register_flownet2(opt_net):
    from models.flownet2.models import FlowNet2
    ld = 'load_path' in opt_net.keys()
    args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
    netG = FlowNet2(args)
    if ld:
        sd = torch.load(opt_net['load_path'])
        netG.load_state_dict(sd['state_dict'])
@@ -1,79 +0,0 @@
import os

import torch
import torch.nn as nn
import torchvision

from trainer.networks import register_model
from utils.util import sequential_checkpoint
from models.arch_util import ConvGnSilu, make_layer


class TecoResblock(nn.Module):
    def __init__(self, nf):
        super(TecoResblock, self).__init__()
        self.nf = nf
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        return identity + x


class TecoUpconv(nn.Module):
    def __init__(self, nf, scale):
        super(TecoUpconv, self).__init__()
        self.nf = nf
        self.scale = scale
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
        x = self.conv3(x)
        return self.final_conv(x)


# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
# Main differences:
# - Uses SiLU instead of ReLU
# - Reference input is in HR space (just makes more sense)
# - Doesn't use transposed convolutions - just uses interpolation instead.
# - Upsample block is slightly more complicated.
class TecoGen(nn.Module):
    def __init__(self, nf, scale):
        super(TecoGen, self).__init__()
        self.nf = nf
        self.scale = scale
        fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
        res_layers = [TecoResblock(nf) for i in range(15)]
        upsample = TecoUpconv(nf, scale)
        everything = [fea_conv] + res_layers + [upsample]
        self.core = nn.Sequential(*everything)

    def forward(self, x, ref=None):
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
        if ref is None:
            ref = torch.zeros_like(x)
        join = torch.cat([x, ref], dim=1)
        join = sequential_checkpoint(self.core, 6, join)
        self.join = join.detach().clone() + .5
        return x + join

    def visual_dbg(self, step, path):
        torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))

    def get_debug_values(self, step, net_name):
        return {'branch_std': self.join.std()}


@register_model
def register_tecogen(opt_net, opt):
    return TecoGen(opt_net['nf'], opt_net['scale'])
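

# Editor's note: minimal forward-pass sketch (editorial addition). TecoGen upsamples the LR
# input by `scale` and predicts a residual on top of the bicubic upsample; `ref` is an
# optional HR-space guess (e.g. the previous output frame) concatenated on the channel axis.
# The sizes below are illustrative only.
def _example_tecogen():
    net = TecoGen(nf=64, scale=2)
    lr = torch.rand(1, 3, 32, 32)
    return net(lr)  # -> (1, 3, 64, 64)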
@@ -1,155 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.inject import Injector
from trainer.networks import register_model
from utils.util import checkpoint


def create_injector(opt, env):
    type = opt['type']
    if type == 'igpt_resolve':
        return ResolveInjector(opt, env)
    return None


class ResolveInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen = opt['generator']
        self.samples = opt['num_samples']
        self.temperature = opt['temperature']

    def forward(self, state):
        gen = self.env['generators'][self.opt['generator']].module
        img = state[self.input]
        b, c, h, w = img.shape
        qimg = gen.quantize(img)
        s, b = qimg.shape
        qimg = qimg[:s//2, :]
        output = qimg.repeat(1, self.samples)

        pad = torch.zeros(1, self.samples, dtype=torch.long).cuda()  # to pad prev output
        with torch.no_grad():
            for _ in range(s//2):
                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
                logits = logits[-1, :, :] / self.temperature
                probs = F.softmax(logits, dim=-1)
                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
                output = torch.cat((output, pred), dim=0)
        output = gen.unquantize(output.reshape(h, w, -1))
        return {self.output: output.permute(2,3,0,1).contiguous()}


class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        attn_mask = torch.full(
            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
        )
        attn_mask = torch.triu(attn_mask, diagonal=1)

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class iGPT2(nn.Module):
    def __init__(
        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
    ):
        super().__init__()

        self.centroids = nn.Parameter(
            torch.from_numpy(np.load(centroids_file)), requires_grad=False
        )
        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        self.clf_head = nn.Linear(embed_dim, 10)  # Fixed num_classes, this is not a classifier.

    def squared_euclidean_distance(self, a, b):
        b = torch.transpose(b, 0, 1)
        a2 = torch.sum(torch.square(a), dim=1, keepdims=True)
        b2 = torch.sum(torch.square(b), dim=0, keepdims=True)
        ab = torch.matmul(a, b)
        d = a2 - 2 * ab + b2
        return d

    def quantize(self, x):
        b, c, h, w = x.shape
        # [B, C, H, W] => [B, H, W, C]
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(-1, c)  # flatten to pixels
        d = self.squared_euclidean_distance(x, self.centroids)
        x = torch.argmin(d, 1)
        x = x.view(b, h, w)

        # Reshape output to [seq_len, batch].
        x = x.view(x.shape[0], -1)  # flatten images into sequences
        x = x.transpose(0, 1).contiguous()  # to shape [seq len, batch]
        return x

    def unquantize(self, x):
        return self.centroids[x]

    def forward(self, x, already_quantized=False):
        """
        Expect input as shape [b, c, h, w]
        """

        if not already_quantized:
            x = self.quantize(x)
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token
        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
        h = torch.cat([sos, h[:-1, :, :]], axis=0)

        # add positional embeddings
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)

        # transformer
        for layer in self.layers:
            h = checkpoint(layer, h)

        h = self.ln_f(h)

        logits = self.head(h)

        return logits, x


@register_model
def register_igpt2(opt_net, opt):
    return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
                 opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
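

# Editor's note: shape sketch (editorial addition). quantize() maps each pixel to its nearest
# centroid index and returns a [seq_len, batch] token sequence; forward() returns per-token
# logits over the vocabulary plus the token sequence it was conditioned on. The centroids
# file and the dimensions below are hypothetical (the centroids array must be num_vocab x 3).
def _example_igpt2(centroids_file='centroids.npy'):
    model = iGPT2(embed_dim=256, num_heads=8, num_layers=4, num_positions=32 * 32,
                  num_vocab=512, centroids_file=centroids_file)
    img = torch.rand(2, 3, 32, 32)
    logits, tokens = model(img)   # logits: [1024, 2, 512], tokens: [1024, 2]
    return logits, tokens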
@@ -1,42 +0,0 @@
import numpy
import torch
from torch.utils.data import DataLoader

from data.torch_dataset import TorchDataset
from models.classifiers.cifar_resnet_branched import ResNet
from models.classifiers.cifar_resnet_branched import BasicBlock

if __name__ == '__main__':
    dopt = {
        'flip': True,
        'crop_sz': None,
        'dataset': 'cifar100',
        'image_size': 32,
        'normalize': False,
        'kwargs': {
            'root': 'E:\\4k6k\\datasets\\images\\cifar100',
            'download': True
        }
    }
    set = TorchDataset(dopt)
    loader = DataLoader(set, num_workers=0, batch_size=32)
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth'))
    model.eval()

    bins = [[] for _ in range(8)]
    for i, batch in enumerate(loader):
        logits, selector = model(batch['hq'], coarse_label=None, return_selector=True)
        for k, s in enumerate(selector):
            for j, b in enumerate(s):
                if b:
                    bins[j].append(batch['labels'][k].item())
        if i > 10:
            break

    import matplotlib.pyplot as plt
    fig, axs = plt.subplots(3,3)
    for i in range(8):
        axs[i%3, i//3].hist(numpy.asarray(bins[i]))
    plt.show()
    print('hi')
@@ -1,70 +0,0 @@
import torch
import numpy as np
from utils import options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from utils.fdpl_util import dct_2d, extract_patches_2d
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F

input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size=128

if __name__ == '__main__':
    opt = option.parse(input_config, is_train=True)
    opt['dist'] = False

    # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
    dataset_opt = opt['datasets']['train']
    train_set = create_dataset(dataset_opt)
    train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    train_loader = create_dataloader(train_set, dataset_opt, opt, None)
    print('Number of train images: {:,d}, iters: {:,d}'.format(
        len(train_set), train_size))

    # calculate the perceptual weights
    master_diff = np.zeros((patch_size, patch_size))
    num_patches = 0
    all_diff_patches = []
    tq = tqdm(train_loader)
    sampled = 0
    for train_data in tq:
        if sampled > 200:
            break
        sampled += 1

        im = rgb2ycbcr(train_data['hq'].double())
        im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
                                        size=im.shape[2:],
                                        mode="bicubic", align_corners=False))
        patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
        patches_hr = dct_2d(patches_hr, norm='ortho')
        patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
        patches_lr = dct_2d(patches_lr, norm='ortho')
        b, p, c, w, h = patches_hr.shape
        diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
        num_patches += b * p
        all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))

    diff_patches = torch.stack(all_diff_patches, dim=0)
    diff_means = torch.sum(diff_patches, dim=0) / num_patches

    torch.save(diff_means, output_file)
    print(diff_means)

    for i in range(3):
        fig, ax = plt.subplots()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        im = ax.imshow(diff_means[i].numpy())
        ax.set_title("mean_diff for channel %i" % (i,))
        fig.colorbar(im, cax=cax, orientation='vertical')
        plt.show()
@@ -1,411 +0,0 @@
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""

import sys
import os.path as osp
import glob
import pickle
from multiprocessing import Pool
import numpy as np
import lmdb
import cv2

sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import data.util as data_util  # noqa: E402
import utils.util as util  # noqa: E402


def main():
    dataset = 'DIV2K_demo'  # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo | test
    mode = 'hq'  # used for vimeo90k and REDS datasets
    # vimeo90k: GT | LR | flow
    # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
    #       train_sharp_flowx4
    if dataset == 'vimeo90k':
        vimeo90k(mode)
    elif dataset == 'REDS':
        REDS(mode)
    elif dataset == 'general':
        opt = {}
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
    elif dataset == 'DIV2K_demo':
        opt = {}
        ## GT
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['name'] = 'DIV2K800_sub_GT'
        general_image_folder(opt)
        ## LR
        opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
        opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
        opt['name'] = 'DIV2K800_sub_bicLRx4'
        general_image_folder(opt)
    elif dataset == 'test':
        test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')


def read_image_worker(path, key):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return (key, img)


def general_image_folder(opt):
    """Create lmdb for general image folders
    Users should define the keys, such as: '0321_s035' for DIV2K sub-images
    If all the images have the same resolution, it will only store one copy of resolution info.
    Otherwise, it will store every resolution info.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False to use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    keys = []
    for img_path in all_img_list:
        keys.append(osp.splitext(osp.basename(img_path))[0])

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    resolutions = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key_byte, data)
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print('Not all images have the same resolution. Save meta info for each image.')

    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def vimeo90k(mode):
    """Create lmdb for the Vimeo90K dataset, each image with a fixed size
    GT: [3, 256, 448]
        Now only need the 4th frame, e.g., 00001_0001_4
    LR: [3, 64, 112]
        1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001

    flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
        Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
        Flow map is quantized by mmcv and saved in png format
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False to use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'hq':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode == 'LR':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    elif mode == 'flow':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 128, 112
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
        if mode == 'flow':
            for j in range(1, 4):
                keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
                keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
        else:
            for j in range(7):
                keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'hq':  # only read the 4th frame for the GT mode
        print('Only keep the 4th frame.')
        all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
        keys = [v for v in keys if v.endswith('_4')]

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### write data to lmdb
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    txn = env.begin(write=True)
    pbar = util.ProgressBar(len(all_img_list))
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'hq':
        meta_info['name'] = 'Vimeo90K_train_GT'
    elif mode == 'lq':
        meta_info['name'] = 'Vimeo90K_train_LR'
    elif mode == 'flow':
        meta_info['name'] = 'Vimeo90K_train_flowx4'
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    key_set = set()
    for key in keys:
        if mode == 'flow':
            a, b, _, _ = key.split('_')
        else:
            a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = list(key_set)
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def REDS(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    key: 000_00000000

    flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
        Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
        Flow map is quantized by mmcv and saved in png format
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False to use limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'train_sharp':
        img_folder = '../../datasets/REDS/train_sharp'
        lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = '../../datasets/REDS/train_sharp_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '../../datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '../../datasets/REDS/train_blur'
        lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '../../datasets/REDS/train_blur_comp'
        lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_flowx4':
        img_folder = '../../datasets/REDS/train_sharp_flowx4'
        lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
        H_dst, W_dst = 360, 320
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def test_lmdb(dataroot, dataset='REDS'):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
    print('Name: ', meta_info['name'])
    print('Resolution: ', meta_info['resolution'])
    print('# keys: ', len(meta_info['keys']))
    # read one image
    if dataset == 'vimeo90k':
        key = '00001_0001_4'
    else:
        key = '000_00000000'
    print('Reading {} for test.'.format(key))
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
    img = img_flat.reshape(H, W, C)
    cv2.imwrite('test.png', img)


if __name__ == "__main__":
    main()
@@ -1,22 +0,0 @@
from torch.utils.tensorboard import SummaryWriter

if __name__ == "__main__":
    writer = SummaryWriter("../experiments/recovered_tb")
    f = open("../experiments/recovered_tb.txt", encoding="utf8")
    console = f.readlines()
    search_terms = [
        ("iter", ", iter: ", ", lr:"),
        ("l_g_total", " l_g_total: ", " switch_temperature:"),
        ("l_d_fake", "l_d_fake: ", " D_fake:")
    ]
    iter = 0
    for line in console:
        if " - INFO: [epoch:" not in line:
            continue
        for name, start, end in search_terms:
            val = line[line.find(start)+len(start):line.find(end)].replace(",", "")
            if name == "iter":
                iter = int(val)
            else:
                writer.add_scalar(name, float(val), iter)
    writer.close()
@@ -1,19 +0,0 @@
import os
import glob


def main():
    folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
    DIV2K(folder)
    print('Finished.')


def DIV2K(path):
    img_path_l = glob.glob(os.path.join(path, '*'))
    for img_path in img_path_l:
        new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
        os.rename(img_path, new_path)


if __name__ == "__main__":
    main()
@@ -1,83 +0,0 @@
import os.path as osp
import logging
import time
import argparse

import os

import torchvision

import utils
import utils.options as option
import utils.util as util
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch

if __name__ == "__main__":
    #### options
    torch.backends.cudnn.benchmark = True
    srg_analyze = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        dataset_opt['n_workers'] = 0
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = ExtensibleTrainer(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        dst_path = "F:\\playground"
        [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]

        corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
                       'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
                       'lq_resampling4x']
        c_counter = 0
        test_set.corruptor.num_corrupts = 0
        test_set.corruptor.random_corruptions = []
        test_set.corruptor.fixed_corruptions = [corruptions[0]]
        corruption_mse = [(0,0) for _ in corruptions]

        tq = tqdm(test_loader)
        batch_size = opt['datasets']['train']['batch_size']
        for data in tq:
            need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
            model.feed_data(data, need_GT=need_GT)
            model.test()
            est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1,2,3])
            for i in range(est_psnr.shape[0]):
                im_path = data['GT_path'][i]
                torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path)))
                #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))

            last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
            corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
            c_counter += 1
            test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
            if c_counter % 100 == 0:
                for i, (mse, ctr) in enumerate(corruption_mse):
                    print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))
@@ -1,136 +0,0 @@
import numpy as np
import torch
import torch.nn as nn

# note: all dct related functions are either exactly as or based on those
# at https://github.com/zh217/torch-dct
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.rfft(v, 1, onesided=False)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct(dct(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """

    x_shape = X.shape
    N = x_shape[-1]

    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_r = V_t_r.to(X.device)
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
    V_t_i = V_t_i.to(X.device)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    v = torch.irfft(V, 1, onesided=False)
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)


def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    return X2.transpose(-1, -2)


def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
    Our definition of idct is that idct_2d(dct_2d(x)) == x
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    return x2.transpose(-1, -2)
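

# Editor's note: quick self-check sketch (editorial addition). With norm='ortho' the pair
# above is designed so that idct_2d(dct_2d(x)) recovers x up to floating-point error
# (this relies on the torch.rfft/irfft API the original file was written against).
def _example_dct_roundtrip():
    x = torch.randn(4, 3, 32, 32)
    recon = idct_2d(dct_2d(x, norm='ortho'), norm='ortho')
    return torch.allclose(x, recon, atol=1e-5)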


def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False):
    """
    source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]
    if(img.size(2) < patch_H):
        num_padded_H_Top = (patch_H - img.size(2))//2
        num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
        padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
        img = padding_H(img)
    if(img.size(3) < patch_W):
        num_padded_W_Left = (patch_W - img.size(3))//2
        num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
        padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0)
        img = padding_W(img)
    step_int = [0, 0]
    step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
    step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
    patches_fold_H = img.unfold(2, patch_H, step_int[0])
    if((img.size(2) - patch_H) % step_int[0] != 0):
        patches_fold_H = torch.cat((patches_fold_H,
                                    img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2)
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
    if((img.size(3) - patch_W) % step_int[1] != 0):
        patches_fold_HW = torch.cat((patches_fold_HW,
                                     patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
    patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
    patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
    if(batch_first):
        patches = patches.permute(1, 0, 2, 3, 4)
    return patches
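

# Editor's note: usage sketch (editorial addition). With batch_first=True the helper returns
# patches shaped (batch, num_patches, channels, patch_H, patch_W); the default fractional
# step of [1.0, 1.0] yields non-overlapping patches of the requested size.
def _example_extract_patches():
    img = torch.randn(2, 3, 128, 128)
    patches = extract_patches_2d(img, patch_shape=(64, 64), batch_first=True)
    return patches.shape  # -> (2, 4, 3, 64, 64)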