Clean up codebase

Remove code that I'm likely not going to use again (mostly failed experiments)
James Betker 2021-09-29 09:21:44 -06:00
parent 4d1a42e944
commit 55b58fb67f
24 changed files with 0 additions and 2963 deletions

View File

@@ -1,507 +0,0 @@
import math
import copy
import os
import random
from functools import wraps, partial
from math import floor
import torch
import torchvision
from torch import nn, einsum
import torch.nn.functional as F
from kornia import augmentation as augs
from kornia import filters, color
from einops import rearrange
# helper functions
from trainer.networks import register_model, create_model
def identity(t):
return t
def default(val, def_val):
return def_val if val is None else val
def rand_true(prob):
return random.random() < prob
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
_, _, orig_h, orig_w = image.shape
ratio_lo, ratio_hi = ratio_range
random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
coor_x = floor((orig_w - w) * random.random())
coor_y = floor((orig_h - h) * random.random())
return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio
def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
shape = image.shape
output_size = default(output_size, shape[2:])
(y0, y1), (x0, x1) = coordinates
cutout_image = image[:, :, y0:y1, x0:x1]
return F.interpolate(cutout_image, size = output_size, mode = mode)
def scale_coords(coords, scale):
output = [[0,0],[0,0]]
for j in range(2):
for k in range(2):
output[j][k] = int(coords[j][k] / scale)
return output
def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
blank = torch.zeros_like(image)
coordinates = scale_coords(coordinates, scale_reduction)
(y0, y1), (x0, x1) = coordinates
orig_cutout_shape = (y1-y0, x1-x0)
if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
return None
un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
blank[:,:,y0:y1,x0:x1] = un_resized_img
return blank
def compute_shared_coords(coords1, coords2, scale_reduction):
(y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
(y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
(max(x1_l, x2_l), min(x1_r, x2_r)))
    # Bail out if the shared region is degenerate (no overlap along either axis).
    (yt, yb), (xl, xr) = shared
    if yb - yt <= 0 or xr - xl <= 0:
        return None
    return shared
def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
# Unflip the pixel projections
proj_pixel_one = flip_image_one_fn(proj_pixel_one)
proj_pixel_two = flip_image_two_fn(proj_pixel_two)
# Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
mode=interp_mode)
proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
mode=interp_mode)
if proj_pixel_one is None or proj_pixel_two is None:
print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
        return None, None
# Compute the shared coordinates for the two cutouts:
shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
if shared_coords is None:
print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..")
return None
(yt, yb), (xl, xr) = shared_coords
return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
# augmentation utils
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# exponential moving average
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
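# Hypothetical usage sketch for the EMA helpers above (not part of the original file): keep a frozen
# deep copy of the online network and, after each optimizer step, nudge it toward the online weights.
# PixelCL below uses exactly this pattern for its target encoder.
#
#   online = nn.Linear(4, 4)
#   target = copy.deepcopy(online)
#   updater = EMA(beta=0.99)
#   update_moving_average(updater, target, online)   # call once per training step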
# loss fn
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
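# Note (added for clarity, not in the original file): for unit-normalized x and y,
# ||x - y||^2 = 2 - 2 * <x, y>, so loss_fn above is the squared L2 distance between the
# normalized vectors -- the standard BYOL-style regression loss. A quick check:
#
#   a, b = torch.randn(4, 256), torch.randn(4, 256)
#   an, bn = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
#   assert torch.allclose((an - bn).pow(2).sum(dim=-1), loss_fn(a, b), atol=1e-5)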
# classes
class MLP(nn.Module):
def __init__(self, chan, chan_out = 256, inner_dim = 2048):
super().__init__()
self.net = nn.Sequential(
nn.Linear(chan, inner_dim),
nn.BatchNorm1d(inner_dim),
nn.ReLU(),
nn.Linear(inner_dim, chan_out)
)
def forward(self, x):
return self.net(x)
class ConvMLP(nn.Module):
def __init__(self, chan, chan_out = 256, inner_dim = 2048):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, inner_dim, 1),
nn.BatchNorm2d(inner_dim),
nn.ReLU(),
nn.Conv2d(inner_dim, chan_out, 1)
)
def forward(self, x):
return self.net(x)
class PPM(nn.Module):
def __init__(
self,
*,
chan,
num_layers = 1,
gamma = 2):
super().__init__()
self.gamma = gamma
if num_layers == 0:
self.transform_net = nn.Identity()
elif num_layers == 1:
self.transform_net = nn.Conv2d(chan, chan, 1)
elif num_layers == 2:
self.transform_net = nn.Sequential(
nn.Conv2d(chan, chan, 1),
nn.BatchNorm2d(chan),
nn.ReLU(),
nn.Conv2d(chan, chan, 1)
)
else:
raise ValueError('num_layers must be one of 0, 1, or 2')
    def forward(self, x):
        # Pixel propagation: every output position (x, y) receives a similarity-weighted sum of the
        # transformed features at all positions (h, w); the ReLU'd cosine similarities are sharpened by gamma.
        xi = x[:, :, :, :, None, None]
        xj = x[:, :, None, None, :, :]
        similarity = F.relu(F.cosine_similarity(xi, xj, dim=1)) ** self.gamma
        transform_out = self.transform_net(x)
        out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
        return out
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(
self,
*,
net,
instance_projection_size,
instance_projection_hidden_size,
pix_projection_size,
pix_projection_hidden_size,
layer_pixel = -2,
layer_instance = -2
):
super().__init__()
self.net = net
self.layer_pixel = layer_pixel
self.layer_instance = layer_instance
self.pixel_projector = None
self.instance_projector = None
self.instance_projection_size = instance_projection_size
self.instance_projection_hidden_size = instance_projection_hidden_size
self.pix_projection_size = pix_projection_size
self.pix_projection_hidden_size = pix_projection_hidden_size
self.hidden_pixel = None
self.hidden_instance = None
self.hook_registered = False
def _find_layer(self, layer_id):
if type(layer_id) == str:
modules = dict([*self.net.named_modules()])
return modules.get(layer_id, None)
elif type(layer_id) == int:
children = [*self.net.children()]
return children[layer_id]
return None
def _hook(self, attr_name, _, __, output):
setattr(self, attr_name, output)
def _register_hook(self):
pixel_layer = self._find_layer(self.layer_pixel)
instance_layer = self._find_layer(self.layer_instance)
assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'
pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel'))
instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance'))
self.hook_registered = True
@singleton('pixel_projector')
def _get_pixel_projector(self, hidden):
_, dim, *_ = hidden.shape
projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
return projector.to(hidden)
@singleton('instance_projector')
def _get_instance_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden_pixel = self.hidden_pixel
hidden_instance = self.hidden_instance
self.hidden_pixel = None
self.hidden_instance = None
assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
return hidden_pixel, hidden_instance
def forward(self, x):
pixel_representation, instance_representation = self.get_representation(x)
instance_representation = instance_representation.flatten(1)
pixel_projector = self._get_pixel_projector(pixel_representation)
instance_projector = self._get_instance_projector(instance_representation)
pixel_projection = pixel_projector(pixel_representation)
instance_projection = instance_projector(instance_representation)
return pixel_projection, instance_projection
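# Standalone sketch (hypothetical, not part of the original file) of the forward-hook mechanism
# NetWrapper relies on: a hook grabs an intermediate activation during a normal forward pass,
# and that activation is then fed to the projection heads.
#
#   body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))
#   captured = {}
#   body[1].register_forward_hook(lambda module, inp, out: captured.update(hidden=out))
#   _ = body(torch.randn(1, 3, 16, 16))
#   print(captured['hidden'].shape)   # the intercepted hidden activation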
# main class
class PixelCL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer_pixel = -2,
hidden_layer_instance = -2,
instance_projection_size = 256,
instance_projection_hidden_size = 2048,
pix_projection_size = 256,
pix_projection_hidden_size = 2048,
augment_fn = None,
augment_fn2 = None,
prob_rand_hflip = 0.25,
moving_average_decay = 0.99,
ppm_num_layers = 1,
ppm_gamma = 2,
distance_thres = 0.7,
similarity_temperature = 0.3,
cutout_ratio_range = (0.6, 0.8),
cutout_interpolate_mode = 'nearest',
coord_cutout_interpolate_mode = 'bilinear',
        max_latent_dim = None  # When set, the number of pixels stochastically sampled from the latent. Must have an integer square root.
):
super().__init__()
DEFAULT_AUG = nn.Sequential(
RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomSolarize(p=0.5),
# Normalize left out because it should be done at the model level.
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
self.prob_rand_hflip = prob_rand_hflip
self.online_encoder = NetWrapper(
net = net,
instance_projection_size = instance_projection_size,
instance_projection_hidden_size = instance_projection_hidden_size,
pix_projection_size = pix_projection_size,
pix_projection_hidden_size = pix_projection_hidden_size,
layer_pixel = hidden_layer_pixel,
layer_instance = hidden_layer_instance
)
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.distance_thres = distance_thres
self.similarity_temperature = similarity_temperature
        # This requirement is due to the way these are processed, not a hard constraint.
        if max_latent_dim is not None:
            assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
self.max_latent_dim = max_latent_dim
self.propagate_pixels = PPM(
chan = pix_projection_size,
num_layers = ppm_num_layers,
gamma = ppm_gamma
)
self.cutout_ratio_range = cutout_ratio_range
self.cutout_interpolate_mode = cutout_interpolate_mode
self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode
# instance level predictor
self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_moving_average(self):
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def forward(self, x):
shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip
rand_flip_fn = lambda t: torch.flip(t, dims = (-1,))
flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
flip_image_one_fn = rand_flip_fn if flip_image_one else identity
flip_image_two_fn = rand_flip_fn if flip_image_two else identity
cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)
image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode)
image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode)
image_one_cutout = flip_image_one_fn(image_one_cutout)
image_two_cutout = flip_image_two_fn(image_two_cutout)
image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)
self.aug1 = image_one_cutout.detach().clone()
self.aug2 = image_two_cutout.detach().clone()
proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)
proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
if proj_pixel_one is None or proj_pixel_two is None:
positive_pixel_pairs = 0
else:
positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]
with torch.no_grad():
target_encoder = self._get_target_encoder()
target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
b, c, pp_h, pp_w = proj_pixel_one.shape
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
            prob = torch.full((pp_h * pp_w,), 1 / (pp_h * pp_w))  # uniform over all latent positions
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
extracted = []
for l in latents:
l = l.reshape(b, c, pp_h * pp_w)
l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
# For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
# Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves and others without regards
# to the original image structure.
sqdim = int(math.sqrt(self.max_latent_dim))
extracted.append(l.reshape(b, c, sqdim, sqdim))
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
# flatten all the pixel projections
flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))
# get instance level loss
pred_instance_one = self.online_predictor(proj_instance_one)
pred_instance_two = self.online_predictor(proj_instance_two)
loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
instance_loss = (loss_instance_one + loss_instance_two).mean()
if positive_pixel_pairs == 0:
return instance_loss, 0
# calculate pix pro loss
propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
loss_pixpro_one_two = - propagated_similarity_one_two.mean()
loss_pixpro_two_one = - propagated_similarity_two_one.mean()
pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
return instance_loss, pix_loss, positive_pixel_pairs
# Allows visualizing what the augmentor is up to.
def visual_dbg(self, step, path):
if not hasattr(self, 'aug1'):
return
torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
@register_model
def register_pixel_contrastive_learner(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
kwargs = opt_net['kwargs']
if 'subnet_pretrain_path' in opt_net.keys():
sd = torch.load(opt_net['subnet_pretrain_path'])
subnet.load_state_dict(sd, strict=False)
return PixelCL(subnet, **kwargs)

View File

@@ -1,152 +0,0 @@
# Resnet implementation that adds a u-net style up-conversion component to output values at a
# specified pixel density.
#
# The downsampling part of the network is compatible with the built-in torch resnet for use in
# transfer learning.
#
# Only resnet50 currently supported.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class ReverseBottleneck(nn.Module):
def __init__(self, inplanes, planes, groups=1, passthrough=False,
base_width=64, dilation=1, norm_layer=None):
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.passthrough = passthrough
if passthrough:
self.integrate = conv1x1(inplanes*2, inplanes)
self.bn_integrate = norm_layer(inplanes)
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, groups, dilation)
self.bn2 = norm_layer(width)
self.residual_upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv1x1(width, width),
norm_layer(width),
)
self.conv3 = conv1x1(width, planes)
self.bn3 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv1x1(inplanes, planes),
norm_layer(planes),
)
def forward(self, x, passthrough=None):
if self.passthrough:
x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.residual_upsample(out)
out = self.conv3(out)
out = self.bn3(out)
identity = self.upsample(x)
out = out + identity
out = self.relu(out)
return out
class UResNet50(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, out_dim=128):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
'''
# For reference:
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
'''
uplayers = []
inplanes = 2048
first = True
for i in range(2):
uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
inplanes = inplanes // 2
first = False
self.uplayers = nn.ModuleList(uplayers)
self.tail = nn.Sequential(conv1x1(1024, 512),
norm_layer(512),
nn.ReLU(),
conv3x3(512, 512),
norm_layer(512),
nn.ReLU(),
conv1x1(512, out_dim))
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
def _forward_impl(self, x):
        # Should be the exact same implementation as torchvision.models.resnet.ResNet._forward_impl,
        # except using checkpoints on the body conv layers.
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x1 = checkpoint(self.layer1, x)
x2 = checkpoint(self.layer2, x1)
x3 = checkpoint(self.layer3, x2)
x4 = checkpoint(self.layer4, x3)
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
x = checkpoint(self.uplayers[0], x4)
x = checkpoint(self.uplayers[1], x, x3)
#x = checkpoint(self.uplayers[2], x, x2)
#x = checkpoint(self.uplayers[3], x, x1)
return checkpoint(self.tail, torch.cat([x, x2], dim=1))
def forward(self, x):
return self._forward_impl(x)
@register_model
def register_u_resnet50(opt_net, opt):
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
if opt_get(opt_net, ['use_pretrained_base'], False):
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == '__main__':
    model = UResNet50(Bottleneck, [3, 4, 6, 3])
    samp = torch.rand(1, 3, 224, 224)
    y = model(samp)
    print(y.shape)  # Expect torch.Size([1, 128, 28, 28]): the decoder brings the 1/32-scale trunk back up to 1/8 scale.
# For pixpro: attach to "tail.3"

View File

@@ -1,87 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision
from models.arch_util import ConvBnRelu
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class UResNet50_2(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, out_dim=128):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.level_conv = ConvBnRelu(3, 64)
'''
# For reference:
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
'''
uplayers = []
inplanes = 2048
first = True
div = [2,2,2,4,1]
for i in range(5):
uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
inplanes = inplanes // div[i]
first = False
self.uplayers = nn.ModuleList(uplayers)
self.tail = nn.Sequential(conv3x3(128, 64),
norm_layer(64),
nn.ReLU(),
conv1x1(64, out_dim))
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
def _forward_impl(self, x):
level = self.level_conv(x)
x0 = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x0)
x1 = checkpoint(self.layer1, x)
x2 = checkpoint(self.layer2, x1)
x3 = checkpoint(self.layer3, x2)
x4 = checkpoint(self.layer4, x3)
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
x = checkpoint(self.uplayers[0], x4)
x = checkpoint(self.uplayers[1], x, x3)
x = checkpoint(self.uplayers[2], x, x2)
x = checkpoint(self.uplayers[3], x, x1)
x = checkpoint(self.uplayers[4], x, x0)
return checkpoint(self.tail, torch.cat([x, level], dim=1))
def forward(self, x):
return self._forward_impl(x)
@register_model
def register_u_resnet50_2(opt_net, opt):
model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
if opt_get(opt_net, ['use_pretrained_base'], False):
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == '__main__':
model = UResNet50_2(Bottleneck, [3,4,6,3])
samp = torch.rand(1,3,224,224)
y = model(samp)
print(y.shape)
# For pixpro: attach to "tail.3"

View File

@@ -1,86 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision
from models.arch_util import ConvBnRelu
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class UResNet50_3(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, out_dim=128):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
'''
# For reference:
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
'''
uplayers = []
inplanes = 2048
first = True
for i in range(3):
uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
inplanes = inplanes // 2
first = False
self.uplayers = nn.ModuleList(uplayers)
# These two variables are separated out and renamed so that I can re-use parameters from a pretrained resnet_unet2.
self.last_uplayer = ReverseBottleneck(256, 128, norm_layer=norm_layer, passthrough=True)
self.tail3 = nn.Sequential(conv1x1(192, 128),
norm_layer(128),
nn.ReLU(),
conv1x1(128, out_dim))
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
def _forward_impl(self, x):
x0 = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x0)
x1 = checkpoint(self.layer1, x)
x2 = checkpoint(self.layer2, x1)
x3 = checkpoint(self.layer3, x2)
x4 = checkpoint(self.layer4, x3)
unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused.
x = checkpoint(self.uplayers[0], x4)
x = checkpoint(self.uplayers[1], x, x3)
x = checkpoint(self.uplayers[2], x, x2)
x = checkpoint(self.last_uplayer, x, x1)
return checkpoint(self.tail3, torch.cat([x, x0], dim=1))
def forward(self, x):
return self._forward_impl(x)
@register_model
def register_u_resnet50_3(opt_net, opt):
model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
if opt_get(opt_net, ['use_pretrained_base'], False):
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == '__main__':
model = UResNet50_3(Bottleneck, [3,4,6,3])
samp = torch.rand(1,3,224,224)
y = model(samp)
print(y.shape)
# For pixpro: attach to "tail.3"

View File

@@ -1,9 +0,0 @@
from models.styled_sr.discriminator import StyleSrGanDivergenceLoss
def create_stylesr_loss(opt_loss, env):
type = opt_loss['type']
if type == 'style_sr_gan_divergence_loss':
return StyleSrGanDivergenceLoss(opt_loss, env)
else:
raise NotImplementedError

View File

@@ -1,344 +0,0 @@
# Heavily based on the lucidrains stylegan2 discriminator implementation.
import math
import os
from functools import partial
from math import log2
from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad as torch_grad
import trainer.losses as L
from vector_quantize_pytorch import VectorQuantize
from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
super().__init__()
self.filters = filters
self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)
self.net = nn.Sequential(
TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu(),
TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class StyleSrDiscriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
transfer_mode=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
blocks = []
filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
chan_in_out = list(zip(filters[:-1], filters[1:]))
blocks = []
attn_blocks = []
quantize_blocks = []
for ind, (in_chan, out_chan) in enumerate(chan_in_out):
num_layer = ind + 1
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
blocks.append(block)
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
attn_blocks.append(attn_fn)
if quantize:
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
quantize_blocks.append(quantize_fn)
else:
quantize_blocks.append(None)
self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
self.quantize_blocks = nn.ModuleList(quantize_blocks)
self.do_checkpointing = do_checkpointing
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
self.flatten = nn.Flatten()
if mlp:
self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
leaky_relu(),
TransferLinear(100, 1, transfer_mode=transfer_mode))
else:
self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)
self._init_weights()
self.transfer_mode = transfer_mode
if transfer_mode:
for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
p.DO_NOT_TRAIN = True
def forward(self, x):
b, *_ = x.shape
quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
if self.do_checkpointing:
x = checkpoint(block, x)
else:
x = block(x)
if exists(attn_block):
x = attn_block(x)
if exists(q_block):
x, _, loss = q_block(x)
quantize_loss += loss
x = self.final_conv(x)
x = self.flatten(x)
x = self.to_logit(x)
if exists(q_block):
return x.squeeze(), quantize_loss
else:
return x.squeeze()
def _init_weights(self):
for m in self.modules():
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Configures the network as partially pre-trained. This means:
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
# 2) The head (linear layers) will also have their weights re-initialized
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
# These settings will be applied after the weights have been loaded (network_loaded())
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
self.bypass_blocks = bypass_blocks
self.num_blocks = num_blocks
self.frozen_until_step = frozen_until_step
# Called after the network weights are loaded.
def network_loaded(self):
if not hasattr(self, 'frozen_until_step'):
return
if self.bypass_blocks > 0:
self.blocks = self.blocks[self.bypass_blocks:]
self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
reset_blocks = [self.to_logit]
for i in range(self.num_blocks):
reset_blocks.append(self.blocks[i])
for bl in reset_blocks:
for m in bl.modules():
if type(m) in {TransferConv2d, TransferLinear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for p in m.parameters(recurse=True):
p._NEW_BLOCK = True
for p in self.parameters():
if not hasattr(p, '_NEW_BLOCK'):
p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
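# Hypothetical call order for the partial pre-training scheme described above configure_partial_training
# (not part of the original file):
#
#   disc = StyleSrDiscriminator(image_size=128)
#   disc.configure_partial_training(bypass_blocks=1, num_blocks=2, frozen_until_step=20000)
#   disc.load_state_dict(pretrained_sd, strict=False)   # `pretrained_sd` is a previously trained state dict
#   disc.network_loaded()   # re-initializes the top blocks and marks the remaining params frozen until the given step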
# helper classes
def DiffAugment(x, types=[]):
for p in types:
for f in AUGMENT_FNS[p]:
x = f(x)
return x.contiguous()
def random_hflip(tensor, prob):
if prob > random():
return tensor
return torch.flip(tensor, dims=(3,))
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
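# Hypothetical usage of the DiffAugment helpers above (not part of the original file): the keys of
# AUGMENT_FNS select which families of differentiable augmentations are applied, in the order given.
#
#   imgs = torch.randn(4, 3, 64, 64)
#   imgs_aug = DiffAugment(imgs, types=['color', 'translation', 'cutout'])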
class DiscAugmentor(nn.Module):
def __init__(self, D, image_size, types, prob):
super().__init__()
self.D = D
self.prob = prob
self.types = types
def forward(self, images, real_images=False):
if random() < self.prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=self.types)
if real_images:
self.hq_aug = images.detach().clone()
else:
self.gen_aug = images.detach().clone()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
return self.D(images)
def network_loaded(self):
self.D.network_loaded()
# Allows visualizing what the augmentor is up to.
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
# 1-centered gradient penalty (WGAN-GP style): weight * (||grad_x D(x)||_2 - 1)^2, averaged over the batch.
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
batch_size = images.shape[0]
gradients = torch_grad(outputs=output, inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
flat_grad = gradients.reshape(batch_size, -1)
penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
if return_structured_grads:
return penalty, gradients
else:
return penalty
class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input, real_images=False)
if self.for_gen:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input, real_images=True)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
real_input.requires_grad_(requires_grad=False)
return divergence_loss
@register_model
def register_styledsr_discriminator(opt_net, opt):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False),
mlp=opt_get(opt_net, ['mlp_head'], True),
transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
)
if 'use_partial_pretrained' in opt_net.keys():
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])

View File

@@ -1,199 +0,0 @@
from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import kaiming_init
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
def rrdb_init_weights(module, scale=1):
for m in module.modules():
if isinstance(m, TransferConv2d):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
elif isinstance(m, TransferLinear):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
class EncoderRRDB(nn.Module):
def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
super(EncoderRRDB, self).__init__()
for i in range(5):
out_channels = output_channels if i == 4 else growth_channels
self.add_module(
f'conv{i+1}',
TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
for i in range(5):
rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5
class StyledSrEncoder(nn.Module):
def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
super().__init__()
        # Currently assumes fea_out=256.
self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
self.rrdbs = nn.ModuleList([
EncoderRRDB(32, transfer_mode=transfer_mode),
EncoderRRDB(64, transfer_mode=transfer_mode),
EncoderRRDB(96, transfer_mode=transfer_mode),
EncoderRRDB(128, transfer_mode=transfer_mode),
EncoderRRDB(160, transfer_mode=transfer_mode),
EncoderRRDB(192, transfer_mode=transfer_mode),
EncoderRRDB(224, transfer_mode=transfer_mode)])
def forward(self, x):
fea = self.initial_conv(x)
for rrdb in self.rrdbs:
fea = torch.cat([fea, checkpoint(rrdb, fea)], dim=1)
return fea
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
super().__init__()
total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample.
self.image_size = image_size
self.scale = 2 ** upsample_levels
self.latent_dim = latent_dim
self.num_layers = total_levels
self.transfer_mode = transfer_mode
filters = [
512, # 4x4
512, # 8x8
512, # 16x16
256, # 32x32
128, # 64x64
64, # 128x128
32, # 256x256
16, # 512x512
8, # 1024x1024
]
# I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted.
self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)
in_out_pairs = list(zip(filters[:-1], filters[1:]))
self.blocks = nn.ModuleList([])
for ind in range(start_level, start_level+total_levels):
in_chan, out_chan = in_out_pairs[ind]
not_first = ind != start_level
not_last = ind != (start_level+total_levels-1)
block = GeneratorBlock(
latent_dim,
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last,
transfer_learning_mode=transfer_mode
)
self.blocks.append(block)
def forward(self, lr, styles):
b, c, h, w = lr.shape
if self.transfer_mode:
with torch.no_grad():
x = self.encoder(lr)
else:
x = self.encoder(lr)
styles = styles.transpose(0, 1)
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
if h != x.shape[-2]:
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
else:
rgb = lr
for style, block in zip(styles, self.blocks):
x, rgb = checkpoint(block, x, rgb, style, input_noise)
return rgb
class StyledSrGenerator(nn.Module):
def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
super().__init__()
        # Assume the vectorizer doesn't need transfer_mode=True. Re-evaluate this later.
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
self.l2 = nn.MSELoss()
self.mixed_prob = .9
self._init_weights()
self.transfer_mode = transfer_mode
self.initial_stride = initial_stride
if transfer_mode:
for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
p.DO_NOT_TRAIN = True
def _init_weights(self):
for m in self.modules():
if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
nn.init.zeros_(block.to_noise1.weight)
nn.init.zeros_(block.to_noise2.weight)
nn.init.zeros_(block.to_noise1.bias)
nn.init.zeros_(block.to_noise2.bias)
def forward(self, x):
b, f, h, w = x.shape
# Synthesize style latents from noise.
style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
if self.transfer_mode:
with torch.no_grad():
w = self.vectorizer(style)
else:
w = self.vectorizer(style)
# Randomly distribute styles across layers
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
for j in range(b):
cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
if cutoff == self.gen.num_layers or random() > self.mixed_prob:
w_styles[j] = w_styles[j*2]
else:
w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
w_styles = w_styles[:b]
out = self.gen(x, w_styles)
# Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
# for regularization.
out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
if self.initial_stride > 1:
x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
l2_reg = self.l2(x, out_down)
return out, l2_reg, w_styles
if __name__ == '__main__':
gen = StyledSrGenerator(128, 2)
out = gen(torch.rand(1,3,64,64))
print([o.shape for o in out])
@register_model
def register_styled_sr(opt_net, opt):
return StyledSrGenerator(128,
initial_stride=opt_get(opt_net, ['initial_stride'], 1),
transfer_mode=opt_get(opt_net, ['transfer_mode'], False))

View File

@@ -1,411 +0,0 @@
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn, Tensor
from torch.autograd import grad as torch_grad
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from models.styled_sr.transfer_primitives import TransferLinear
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
num_cores = multiprocessing.cpu_count()
# constants
EPS = 1e-8
class NanException(Exception):
pass
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if not exists(old):
return new
return old * self.beta + (1 - self.beta) * new
class Flatten(nn.Module):
def forward(self, x):
return x.reshape(x.shape[0], -1)
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.fn(x) * self.g
class PermuteToFrom(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = x.permute(0, 2, 3, 1)
out, loss = self.fn(x)
out = out.permute(0, 3, 1, 2)
return out, loss
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x):
f = self.f
f = f[None, None, :] * f[None, :, None]
return filter2D(x, f, normalized=True)
# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])
# helpers
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def combine_contexts(contexts):
@contextmanager
def multi_contexts():
with ExitStack() as stack:
yield [stack.enter_context(ctx()) for ctx in contexts]
return multi_contexts
def default(value, d):
return value if exists(value) else d
def cycle(iterable):
while True:
for i in iterable:
yield i
def cast_list(el):
return el if isinstance(el, list) else [el]
def is_empty(t):
if isinstance(t, torch.Tensor):
return t.nelement() == 0
return not exists(t)
def raise_if_nan(t):
if torch.isnan(t):
raise NanException
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
if is_ddp:
num_no_syncs = gradient_accumulate_every - 1
head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
tail = [null_context]
contexts = head + tail
else:
contexts = [null_context] * gradient_accumulate_every
for context in contexts:
with context():
yield
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
def calc_pl_lengths(styles, images):
device = images.device
num_pixels = images.shape[2] * images.shape[3]
pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
outputs = (images * pl_noise).sum()
pl_grads = torch_grad(outputs=outputs, inputs=styles,
grad_outputs=torch.ones(outputs.shape, device=device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def image_noise(n, im_size, device):
return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)
def leaky_relu(p=0.2):
return nn.LeakyReLU(p, inplace=True)
def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
if len(chunked_outputs) == 1:
return chunked_outputs[0]
return torch.cat(chunked_outputs, dim=0)
def set_requires_grad(model, bool):
for p in model.parameters():
p.requires_grad = bool
def slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim))
self.lr_mul = lr_mul
self.transfer_mode = transfer_mode
if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def forward(self, input):
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul)
class StyleVectorizer(nn.Module):
def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
super().__init__()
layers = []
for i in range(depth):
layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()])
self.net = nn.Sequential(*layers)
def forward(self, x):
x = F.normalize(x, dim=1)
return self.net(x)
class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
super().__init__()
self.input_channel = input_channel
self.to_style = nn.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
Blur()
) if upsample else None
def forward(self, x, prev_rgb, istyle):
b, c, h, w = x.shape
style = self.to_style(istyle)
x = self.conv(x, style)
if exists(prev_rgb):
x = x + prev_rgb
if exists(self.upsample):
x = self.upsample(x)
return x
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel, style_dim):
super().__init__()
from models.archs.arch_util import ConvGnLelu
self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
self.norm = nn.InstanceNorm2d(in_channel)
def forward(self, input, style):
gamma = self.style2scale(style)
beta = self.style2bias(style)
out = self.norm(input)
out = gamma * out + beta
return out
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * math.sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class Conv2DMod(nn.Module):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
super().__init__()
self.filters = out_chan
self.demod = demod
self.kernel = kernel
self.stride = stride
self.dilation = dilation
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def _get_same_padding(self, size, kernel, dilation, stride):
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
def forward(self, x, y):
b, c, h, w = x.shape
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
w1 = y[:, None, :, None, None]
w2 = weight[None, :, :, :, :]
weights = w2 * (w1 + 1)
        if self.demod:
            # Weight demodulation (StyleGAN2): rescale each modulated filter to approximately unit norm.
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d
x = x.reshape(1, -1, h, w)
_, _, *ws = weights.shape
weights = weights.reshape(b * self.filters, *ws)
padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
x = F.conv2d(x, weights, padding=padding, groups=b)
x = x.reshape(-1, self.filters, h, w)
return x
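# Hypothetical smoke test for the modulated convolution above (not part of the original file): the
# style vector `y` supplies one modulation scale per input channel, per sample.
#
#   conv = Conv2DMod(in_chan=16, out_chan=32, kernel=3)
#   feats = torch.randn(4, 16, 8, 8)
#   style = torch.randn(4, 16)
#   print(conv(feats, style).shape)   # torch.Size([4, 32, 8, 8])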
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
transfer_learning_mode=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode)
self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode)
self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode)
self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode)
self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode)
self.transfer_learning_mode = transfer_learning_mode
def forward(self, x, prev_rgb, istyle, inoise):
if exists(self.upsample):
x = self.upsample(x)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb

View File

@@ -1,136 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _ntuple
_pair = _ntuple(2)
class TransferConv2d(_ConvNd):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
transfer_mode: bool = False
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def _conv_forward(self, input, weight):
if self.transfer_mode:
weight = weight * self.transfer_scale + self.transfer_shift
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)
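# Linear layer with the same transfer_mode scale/shift weight modulation as TransferConv2d.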
class TransferLinear(nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.transfer_mode = transfer_mode
if transfer_mode:
self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features))
self.transfer_scale.FOR_TRANSFER_LEARNING = True
self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features))
self.transfer_shift.FOR_TRANSFER_LEARNING = True
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor) -> Tensor:
if self.transfer_mode:
weight = self.weight * self.transfer_scale + self.transfer_shift
else:
weight = self.weight
return F.linear(input, weight, self.bias)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
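# Convenience block: TransferConv2d -> optional GroupNorm -> optional LeakyReLU, with kaiming
# initialization scaled by weight_init_factor.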
class TransferConvGnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False):
super().__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode)
if norm:
self.gn = nn.GroupNorm(num_groups, filters_out)
else:
self.gn = None
if activation:
self.lelu = nn.LeakyReLU(negative_slope=.2)
else:
self.lelu = None
# Init params.
for m in self.modules():
if isinstance(m, TransferConv2d):
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
nonlinearity='leaky_relu' if self.lelu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv(x)
if self.gn:
x = self.gn(x)
if self.lelu:
return self.lelu(x)
else:
return x

View File

@ -1,15 +0,0 @@
import munch
import torch
from trainer.networks import register_model
@register_model
def register_flownet2(opt_net):
from models.flownet2.models import FlowNet2
ld = 'load_path' in opt_net.keys()
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
netG = FlowNet2(args)
if ld:
sd = torch.load(opt_net['load_path'])
netG.load_state_dict(sd['state_dict'])
return netG

View File

@ -1,79 +0,0 @@
import os
import torch
import torch.nn as nn
import torchvision
from trainer.networks import register_model
from utils.util import sequential_checkpoint
from models.arch_util import ConvGnSilu, make_layer
class TecoResblock(nn.Module):
def __init__(self, nf):
super(TecoResblock, self).__init__()
self.nf = nf
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
return identity + x
class TecoUpconv(nn.Module):
def __init__(self, nf, scale):
super(TecoUpconv, self).__init__()
self.nf = nf
self.scale = scale
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
x = self.conv3(x)
return self.final_conv(x)
# Extremely simple resnet-based generator that is very similar to the one used in the TecoGAN paper.
# Main differences:
# - Uses SiLU instead of ReLU.
# - The reference input is in HR space (which just makes more sense).
# - Doesn't use transposed convolutions - interpolation is used instead.
# - The upsample block is slightly more complicated.
class TecoGen(nn.Module):
def __init__(self, nf, scale):
super(TecoGen, self).__init__()
self.nf = nf
self.scale = scale
fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
res_layers = [TecoResblock(nf) for i in range(15)]
upsample = TecoUpconv(nf, scale)
everything = [fea_conv] + res_layers + [upsample]
self.core = nn.Sequential(*everything)
def forward(self, x, ref=None):
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x)
join = torch.cat([x, ref], dim=1)
join = sequential_checkpoint(self.core, 6, join)
self.join = join.detach().clone() + .5
return x + join
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,)))
def get_debug_values(self, step, net_name):
return {'branch_std': self.join.std()}
@register_model
def register_tecogen(opt_net, opt):
return TecoGen(opt_net['nf'], opt_net['scale'])

View File

@ -1,155 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.inject import Injector
from trainer.networks import register_model
from utils.util import checkpoint
def create_injector(opt, env):
type = opt['type']
if type == 'igpt_resolve':
return ResolveInjector(opt, env)
return None
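# Injector that completes the missing half of a quantized image: it keeps the first half of the
# token sequence and autoregressively samples the remainder from the iGPT generator, producing
# num_samples completions at the given temperature.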
class ResolveInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.gen = opt['generator']
self.samples = opt['num_samples']
self.temperature = opt['temperature']
def forward(self, state):
gen = self.env['generators'][self.opt['generator']].module
img = state[self.input]
b, c, h, w = img.shape
qimg = gen.quantize(img)
s, b = qimg.shape
qimg = qimg[:s//2, :]
output = qimg.repeat(1, self.samples)
pad = torch.zeros(1, self.samples, dtype=torch.long).cuda() # to pad prev output
with torch.no_grad():
for _ in range(s//2):
logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
logits = logits[-1, :, :] / self.temperature
probs = F.softmax(logits, dim=-1)
pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
output = torch.cat((output, pred), dim=0)
output = gen.unquantize(output.reshape(h, w, -1))
return {self.output: output.permute(2,3,0,1).contiguous()}
class Block(nn.Module):
def __init__(self, embed_dim, num_heads):
super(Block, self).__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.GELU(),
nn.Linear(embed_dim * 4, embed_dim),
)
def forward(self, x):
attn_mask = torch.full(
(len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
)
attn_mask = torch.triu(attn_mask, diagonal=1)
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
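# Image-GPT style autoregressive model: pixels are quantized to the nearest centroid from
# centroids_file, flattened into a [seq_len, batch] token sequence, and modeled with a stack of
# causal transformer blocks (gradient-checkpointed).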
class iGPT2(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
):
super().__init__()
self.centroids = nn.Parameter(
torch.from_numpy(np.load(centroids_file)), requires_grad=False
)
self.embed_dim = embed_dim
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
self.position_embeddings = nn.Embedding(num_positions, embed_dim)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, 10) # Fixed num_classes; not referenced in forward() - this model is not used as a classifier here.
def squared_euclidean_distance(self, a, b):
b = torch.transpose(b, 0, 1)
a2 = torch.sum(torch.square(a), dim=1, keepdims=True)
b2 = torch.sum(torch.square(b), dim=0, keepdims=True)
ab = torch.matmul(a, b)
d = a2 - 2 * ab + b2
return d
def quantize(self, x):
b, c, h, w = x.shape
# [B, C, H, W] => [B, H, W, C]
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(-1, c) # flatten to pixels
d = self.squared_euclidean_distance(x, self.centroids)
x = torch.argmin(d, 1)
x = x.view(b, h, w)
# Reshape output to [seq_len, batch].
x = x.view(x.shape[0], -1) # flatten images into sequences
x = x.transpose(0, 1).contiguous() # to shape [seq len, batch]
return x
def unquantize(self, x):
return self.centroids[x]
def forward(self, x, already_quantized=False):
"""
Expect input as shape [b, c, h, w]
"""
if not already_quantized:
x = self.quantize(x)
length, batch = x.shape
h = self.token_embeddings(x)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
h = torch.cat([sos, h[:-1, :, :]], axis=0)
# add positional embeddings
positions = torch.arange(length, device=x.device).unsqueeze(-1)
h = h + self.position_embeddings(positions).expand_as(h)
# transformer
for layer in self.layers:
h = checkpoint(layer, h)
h = self.ln_f(h)
logits = self.head(h)
return logits, x
@register_model
def register_igpt2(opt_net, opt):
return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])

View File

@ -1,42 +0,0 @@
import numpy
import torch
from torch.utils.data import DataLoader
from data.torch_dataset import TorchDataset
from models.classifiers.cifar_resnet_branched import ResNet
from models.classifiers.cifar_resnet_branched import BasicBlock
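# Quick inspection script: runs a branched CIFAR-100 classifier on a few batches, records which
# labels each of the 8 selector branches fires on, and plots per-branch label histograms.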
if __name__ == '__main__':
dopt = {
'flip': True,
'crop_sz': None,
'dataset': 'cifar100',
'image_size': 32,
'normalize': False,
'kwargs': {
'root': 'E:\\4k6k\\datasets\\images\\cifar100',
'download': True
}
}
set = TorchDataset(dopt)
loader = DataLoader(set, num_workers=0, batch_size=32)
model = ResNet(BasicBlock, [2, 2, 2, 2])
model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth'))
model.eval()
bins = [[] for _ in range(8)]
for i, batch in enumerate(loader):
logits, selector = model(batch['hq'], coarse_label=None, return_selector=True)
for k, s in enumerate(selector):
for j, b in enumerate(s):
if b:
bins[j].append(batch['labels'][k].item())
if i > 10:
break
import matplotlib.pyplot as plt
fig, axs = plt.subplots(3,3)
for i in range(8):
axs[i%3, i//3].hist(numpy.asarray(bins[i]))
plt.show()
print('hi')

View File

@ -1,70 +0,0 @@
import torch
import numpy as np
from utils import options as option
from data import create_dataloader, create_dataset
import math
from tqdm import tqdm
from utils.fdpl_util import dct_2d, extract_patches_2d
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.colors import rgb2ycbcr
import torch.nn.functional as F
input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml"
output_file = "fdpr_diff_means.pt"
device = 'cuda'
patch_size=128
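# Estimates FDPL perceptual weights: for ~200 batches, compares the DCT of HR patches against the
# DCT of their bicubically-upsampled LR counterparts and saves the mean normalized per-frequency
# difference to output_file.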
if __name__ == '__main__':
opt = option.parse(input_config, is_train=True)
opt['dist'] = False
# Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML).
dataset_opt = opt['datasets']['train']
train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
total_epochs = int(math.ceil(total_iters / train_size))
train_loader = create_dataloader(train_set, dataset_opt, opt, None)
print('Number of train images: {:,d}, iters: {:,d}'.format(
len(train_set), train_size))
# calculate the perceptual weights
master_diff = np.zeros((patch_size, patch_size))
num_patches = 0
all_diff_patches = []
tq = tqdm(train_loader)
sampled = 0
for train_data in tq:
if sampled > 200:
break
sampled += 1
im = rgb2ycbcr(train_data['hq'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
size=im.shape[2:],
mode="bicubic", align_corners=False))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
patches_hr = dct_2d(patches_hr, norm='ortho')
patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
patches_lr = dct_2d(patches_lr, norm='ortho')
b, p, c, w, h = patches_hr.shape
diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001)
num_patches += b * p
all_diff_patches.append(torch.sum(diffs, dim=(0, 1)))
diff_patches = torch.stack(all_diff_patches, dim=0)
diff_means = torch.sum(diff_patches, dim=0) / num_patches
torch.save(diff_means, output_file)
print(diff_means)
for i in range(3):
fig, ax = plt.subplots()
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
im = ax.imshow(diff_means[i].numpy())
ax.set_title("mean_diff for channel %i" % (i,))
fig.colorbar(im, cax=cax, orientation='vertical')
plt.show()

View File

@ -1,411 +0,0 @@
"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
import sys
import os.path as osp
import glob
import pickle
from multiprocessing import Pool
import numpy as np
import lmdb
import cv2
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import data.util as data_util # noqa: E402
import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
mode = 'hq' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
if dataset == 'vimeo90k':
vimeo90k(mode)
elif dataset == 'REDS':
REDS(mode)
elif dataset == 'general':
opt = {}
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['name'] = 'DIV2K800_sub_GT'
general_image_folder(opt)
elif dataset == 'DIV2K_demo':
opt = {}
## GT
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
opt['name'] = 'DIV2K800_sub_GT'
general_image_folder(opt)
## LR
opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
opt['name'] = 'DIV2K800_sub_bicLRx4'
general_image_folder(opt)
elif dataset == 'test':
test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS')
def read_image_worker(path, key):
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
return (key, img)
def general_image_folder(opt):
"""Create lmdb for general image folders
Users should define the keys, such as: '0321_s035' for DIV2K sub-images
If all the images have the same resolution, it will only store one copy of resolution info.
Otherwise, it will store every resolution info.
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set False to use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
n_thread = 40
########################################################
img_folder = opt['img_folder']
lmdb_save_path = opt['lmdb_save_path']
meta_info = {'name': opt['name']}
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with '.lmdb'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
keys = []
for img_path in all_img_list:
keys.append(osp.splitext(osp.basename(img_path))[0])
if read_all_imgs:
#### read all images to memory (multiprocessing)
dataset = {} # store all image data. list cannot keep the order, use dict
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
'''get the image data and update pbar'''
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### create lmdb environment
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
#### write data to lmdb
pbar = util.ProgressBar(len(all_img_list))
txn = env.begin(write=True)
resolutions = []
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if data.ndim == 2:
H, W = data.shape
C = 1
else:
H, W, C = data.shape
txn.put(key_byte, data)
resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
# check whether all the images are the same size
assert len(keys) == len(resolutions)
if len(set(resolutions)) <= 1:
meta_info['resolution'] = [resolutions[0]]
meta_info['keys'] = keys
print('All images have the same resolution. Simplify the meta info.')
else:
meta_info['resolution'] = resolutions
meta_info['keys'] = keys
print('Not all images have the same resolution. Save meta info for each image.')
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def vimeo90k(mode):
"""Create lmdb for the Vimeo90K dataset, each image with a fixed size
GT: [3, 256, 448]
Now only need the 4th frame, e.g., 00001_0001_4
LR: [3, 64, 112]
1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
key:
Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
Flow map is quantized by mmcv and saved in png format
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set False to use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
if mode == 'hq':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 256, 448
elif mode == 'LR':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 64, 112
elif mode == 'flow':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
H_dst, W_dst = 128, 112
else:
raise ValueError('Wrong dataset mode: {}'.format(mode))
n_thread = 40
########################################################
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with '.lmdb'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
with open(txt_file) as f:
train_l = f.readlines()
train_l = [v.strip() for v in train_l]
all_img_list = []
keys = []
for line in train_l:
folder = line.split('/')[0]
sub_folder = line.split('/')[1]
all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
if mode == 'flow':
for j in range(1, 4):
keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
else:
for j in range(7):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
if mode == 'hq': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
if read_all_imgs:
#### read all images to memory (multiprocessing)
dataset = {} # store all image data. list cannot keep the order, use dict
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
"""get the image data and update pbar"""
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### write data to lmdb
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
txn = env.begin(write=True)
pbar = util.ProgressBar(len(all_img_list))
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if 'flow' in mode:
H, W = data.shape
assert H == H_dst and W == W_dst, 'different shape.'
else:
H, W, C = data.shape
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
txn.put(key_byte, data)
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
meta_info = {}
if mode == 'hq':
meta_info['name'] = 'Vimeo90K_train_GT'
elif mode == 'LR':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'
channel = 1 if 'flow' in mode else 3
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
key_set = set()
for key in keys:
if mode == 'flow':
a, b, _, _ = key.split('_')
else:
a, b, _ = key.split('_')
key_set.add('{}_{}'.format(a, b))
meta_info['keys'] = list(key_set)
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def REDS(mode):
"""Create lmdb for the REDS dataset, each image with a fixed size
GT: [3, 720, 1280], key: 000_00000000
LR: [3, 180, 320], key: 000_00000000
key: 000_00000000
flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
Flow map is quantized by mmcv and saved in png format
"""
#### configurations
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set False to use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
if mode == 'train_sharp':
img_folder = '../../datasets/REDS/train_sharp'
lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_sharp_bicubic':
img_folder = '../../datasets/REDS/train_sharp_bicubic'
lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
H_dst, W_dst = 180, 320
elif mode == 'train_blur_bicubic':
img_folder = '../../datasets/REDS/train_blur_bicubic'
lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
H_dst, W_dst = 180, 320
elif mode == 'train_blur':
img_folder = '../../datasets/REDS/train_blur'
lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_blur_comp':
img_folder = '../../datasets/REDS/train_blur_comp'
lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
H_dst, W_dst = 720, 1280
elif mode == 'train_sharp_flowx4':
img_folder = '../../datasets/REDS/train_sharp_flowx4'
lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
H_dst, W_dst = 360, 320
else:
raise ValueError('Wrong dataset mode: {}'.format(mode))
n_thread = 40
########################################################
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with '.lmdb'.")
if osp.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
#### read all the image paths to a list
print('Reading image path list ...')
all_img_list = data_util._get_paths_from_images(img_folder)
keys = []
for img_path in all_img_list:
split_rlt = img_path.split('/')
folder = split_rlt[-2]
img_name = split_rlt[-1].split('.png')[0]
keys.append(folder + '_' + img_name)
if read_all_imgs:
#### read all images to memory (multiprocessing)
dataset = {} # store all image data. list cannot keep the order, use dict
print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
pbar = util.ProgressBar(len(all_img_list))
def mycallback(arg):
'''get the image data and update pbar'''
key = arg[0]
dataset[key] = arg[1]
pbar.update('Reading {}'.format(key))
pool = Pool(n_thread)
for path, key in zip(all_img_list, keys):
pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
pool.close()
pool.join()
print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))
#### create lmdb environment
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
#### write data to lmdb
pbar = util.ProgressBar(len(all_img_list))
txn = env.begin(write=True)
for idx, (path, key) in enumerate(zip(all_img_list, keys)):
pbar.update('Write {}'.format(key))
key_byte = key.encode('ascii')
data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
if 'flow' in mode:
H, W = data.shape
assert H == H_dst and W == W_dst, 'different shape.'
else:
H, W, C = data.shape
assert H == H_dst and W == W_dst and C == 3, 'different shape.'
txn.put(key_byte, data)
if not read_all_imgs and idx % BATCH == 0:
txn.commit()
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
#### create meta information
meta_info = {}
meta_info['name'] = 'REDS_{}_wval'.format(mode)
channel = 1 if 'flow' in mode else 3
meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
meta_info['keys'] = keys
pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def test_lmdb(dataroot, dataset='REDS'):
env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
print('Name: ', meta_info['name'])
print('Resolution: ', meta_info['resolution'])
print('# keys: ', len(meta_info['keys']))
# read one image
if dataset == 'vimeo90k':
key = '00001_0001_4'
else:
key = '000_00000000'
print('Reading {} for test.'.format(key))
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
img = img_flat.reshape(H, W, C)
cv2.imwrite('test.png', img)
if __name__ == "__main__":
main()

View File

@ -1,22 +0,0 @@
from torch.utils.tensorboard import SummaryWriter
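# Recovers a tensorboard log from a saved console log: scrapes the metric values listed in
# search_terms out of each INFO line and re-writes them with SummaryWriter.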
if __name__ == "__main__":
writer = SummaryWriter("../experiments/recovered_tb")
f = open("../experiments/recovered_tb.txt", encoding="utf8")
console = f.readlines()
search_terms = [
("iter", ", iter: ", ", lr:"),
("l_g_total", " l_g_total: ", " switch_temperature:"),
("l_d_fake", "l_d_fake: ", " D_fake:")
]
iter = 0
for line in console:
if " - INFO: [epoch:" not in line:
continue
for name, start, end in search_terms:
val = line[line.find(start)+len(start):line.find(end)].replace(",", "")
if name == "iter":
iter = int(val)
else:
writer.add_scalar(name, float(val), iter)
writer.close()

View File

@ -1,19 +0,0 @@
import os
import glob
def main():
folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
DIV2K(folder)
print('Finished.')
def DIV2K(path):
img_path_l = glob.glob(os.path.join(path, '*'))
for img_path in img_path_l:
new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
os.rename(img_path, new_path)
if __name__ == "__main__":
main()

View File

@ -1,83 +0,0 @@
import os.path as osp
import logging
import time
import argparse
import os
import torchvision
import utils
import utils.options as option
import utils.util as util
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
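# Evaluation script for a PSNR-approximator network: cycles a list of synthetic corruptions over
# the test set, bins each LQ image into a folder named by its predicted PSNR, and periodically
# prints the mean predicted PSNR per corruption type.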
if __name__ == "__main__":
#### options
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
dataset_opt['n_workers'] = 0
test_set = create_dataset(dataset_opt)
test_loader = create_dataloader(test_set, dataset_opt, opt)
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
model = ExtensibleTrainer(opt)
for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name']
logger.info('\nTesting [{:s}]...'.format(test_set_name))
test_start_time = time.time()
dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
util.mkdir(dataset_dir)
dst_path = "F:\\playground"
[os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)]
corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise',
'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling',
'lq_resampling4x']
c_counter = 0
test_set.corruptor.num_corrupts = 0
test_set.corruptor.random_corruptions = []
test_set.corruptor.fixed_corruptions = [corruptions[0]]
corruption_mse = [(0,0) for _ in corruptions]
tq = tqdm(test_loader)
batch_size = opt['datasets']['train']['batch_size']
for data in tq:
need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
model.feed_data(data, need_GT=need_GT)
model.test()
est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1,2,3])
for i in range(est_psnr.shape[0]):
im_path = data['GT_path'][i]
torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path)))
#shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10))))
last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)]
corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1)
c_counter += 1
test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]]
if c_counter % 100 == 0:
for i, (mse, ctr) in enumerate(corruption_mse):
print("%s: %f" % (corruptions[i], mse / (ctr * batch_size)))

View File

@ -1,136 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
# note: all dct-related functions are either taken directly from, or based on, those
# at https://github.com/zh217/torch-dct
def dct(x, norm=None):
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last dimension
"""
x_shape = x.shape
N = x_shape[-1]
x = x.contiguous().view(-1, N)
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
Vc = torch.rfft(v, 1, onesided=False)
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
if norm == 'ortho':
V[:, 0] /= np.sqrt(N) * 2
V[:, 1:] /= np.sqrt(N / 2) * 2
V = 2 * V.view(*x_shape)
return V
def idct(X, norm=None):
"""
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
Our definition of idct is that idct(dct(x)) == x
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param X: the input signal
:param norm: the normalization, None or 'ortho'
:return: the inverse DCT-II of the signal over the last dimension
"""
x_shape = X.shape
N = x_shape[-1]
X_v = X.contiguous().view(-1, x_shape[-1]) / 2
if norm == 'ortho':
X_v[:, 0] *= np.sqrt(N) * 2
X_v[:, 1:] *= np.sqrt(N / 2) * 2
k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V_t_r = X_v
V_t_r = V_t_r.to(X.device)
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
V_t_i = V_t_i.to(X.device)
V_r = V_t_r * W_r - V_t_i * W_i
V_i = V_t_r * W_i + V_t_i * W_r
V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
v = torch.irfft(V, 1, onesided=False)
x = v.new_zeros(v.shape)
x[:, ::2] += v[:, :N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, :N // 2]
return x.view(*x_shape)
def dct_2d(x, norm=None):
"""
2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last 2 dimensions
"""
X1 = dct(x, norm=norm)
X2 = dct(X1.transpose(-1, -2), norm=norm)
return X2.transpose(-1, -2)
def idct_2d(X, norm=None):
"""
The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
Our definition of idct is that idct_2d(dct_2d(x)) == x
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param X: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last 2 dimensions
"""
x1 = idct(X, norm=norm)
x2 = idct(x1.transpose(-1, -2), norm=norm)
return x2.transpose(-1, -2)
def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False):
"""
source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78
"""
patch_H, patch_W = patch_shape[0], patch_shape[1]
if(img.size(2) < patch_H):
num_padded_H_Top = (patch_H - img.size(2))//2
num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
img = padding_H(img)
if(img.size(3) < patch_W):
num_padded_W_Left = (patch_W - img.size(3))//2
num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0)
img = padding_W(img)
step_int = [0, 0]
step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
patches_fold_H = img.unfold(2, patch_H, step_int[0])
if((img.size(2) - patch_H) % step_int[0] != 0):
patches_fold_H = torch.cat((patches_fold_H,
img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2)
patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
if((img.size(3) - patch_W) % step_int[1] != 0):
patches_fold_HW = torch.cat((patches_fold_HW,
patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
if(batch_first):
patches = patches.permute(1, 0, 2, 3, 4)
return patches