diff --git a/codes/models/image_generation/lightweight_gan.py b/codes/models/image_generation/lightweight_gan.py
deleted file mode 100644
index fc405086..00000000
--- a/codes/models/image_generation/lightweight_gan.py
+++ /dev/null
@@ -1,914 +0,0 @@
-import math
-import multiprocessing
-import random
-from contextlib import contextmanager, ExitStack
-from functools import partial
-from math import log2, floor
-from pathlib import Path
-from random import random
-
-import torch
-import torch.nn.functional as F
-from gsa_pytorch import GSA
-
-import trainer.losses as L
-import torchvision
-from PIL import Image
-from einops import rearrange, reduce
-from kornia import filter2d
-from torch import nn, einsum
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-from models.image_generation.stylegan.stylegan2_lucidrains import gradient_penalty
-from trainer.networks import register_model
-from utils.util import opt_get
-
-
-def DiffAugment(x, types=[]):
-    for p in types:
-        for f in AUGMENT_FNS[p]:
-            x = f(x)
-    return x.contiguous()
-
-
-# """
-# Augmentation functions got images as `x`
-# where `x` is tensor with this dimensions:
-# 0 - count of images
-# 1 - channels
-# 2 - width
-# 3 - height of image
-# """
-
-def rand_brightness(x):
-    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
-    return x
-
-def rand_saturation(x):
-    x_mean = x.mean(dim=1, keepdim=True)
-    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
-    return x
-
-def rand_contrast(x):
-    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
-    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
-    return x
-
-def rand_translation(x, ratio=0.125):
-    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
-    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
-    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
-    grid_batch, grid_x, grid_y = torch.meshgrid(
-        torch.arange(x.size(0), dtype=torch.long, device=x.device),
-        torch.arange(x.size(2), dtype=torch.long, device=x.device),
-        torch.arange(x.size(3), dtype=torch.long, device=x.device),
-    )
-    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
-    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
-    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
-    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
-    return x
-
-def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
-    w, h = x.size(2), x.size(3)
-
-    imgs = []
-    for img in x.unbind(dim = 0):
-        max_h = int(w * ratio * ratio_h)
-        max_v = int(h * ratio * ratio_v)
-
-        value_h = random.randint(0, max_h) * 2 - max_h
-        value_v = random.randint(0, max_v) * 2 - max_v
-
-        if abs(value_h) > 0:
-            img = torch.roll(img, value_h, 2)
-
-        if abs(value_v) > 0:
-            img = torch.roll(img, value_v, 1)
-
-        imgs.append(img)
-
-    return torch.stack(imgs)
-
-def rand_offset_h(x, ratio=1):
-    return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
-
-def rand_offset_v(x, ratio=1):
-    return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
-
-def rand_cutout(x, ratio=0.5):
-    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
-    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
-    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
-    grid_batch, grid_x, grid_y = torch.meshgrid(
-        torch.arange(x.size(0), dtype=torch.long, device=x.device),
-        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
-        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
-    )
-    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
-    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
-    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
-    mask[grid_batch, grid_x, grid_y] = 0
-    x = x * mask.unsqueeze(1)
-    return x
-
-AUGMENT_FNS = {
-    'color': [rand_brightness, rand_saturation, rand_contrast],
-    'offset': [rand_offset],
-    'offset_h': [rand_offset_h],
-    'offset_v': [rand_offset_v],
-    'translation': [rand_translation],
-    'cutout': [rand_cutout],
-}
-
-# constants
-
-NUM_CORES = multiprocessing.cpu_count()
-EXTS = ['jpg', 'jpeg', 'png']
-
-
-# helpers
-
-def exists(val):
-    return val is not None
-
-
-@contextmanager
-def null_context():
-    yield
-
-
-def combine_contexts(contexts):
-    @contextmanager
-    def multi_contexts():
-        with ExitStack() as stack:
-            yield [stack.enter_context(ctx()) for ctx in contexts]
-
-    return multi_contexts
-
-
-def is_power_of_two(val):
-    return log2(val).is_integer()
-
-
-def default(val, d):
-    return val if exists(val) else d
-
-
-def set_requires_grad(model, bool):
-    for p in model.parameters():
-        p.requires_grad = bool
-
-
-def cycle(iterable):
-    while True:
-        for i in iterable:
-            yield i
-
-
-def raise_if_nan(t):
-    if torch.isnan(t):
-        raise NanException
-
-
-def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
-    if is_ddp:
-        num_no_syncs = gradient_accumulate_every - 1
-        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
-        tail = [null_context]
-        contexts = head + tail
-    else:
-        contexts = [null_context] * gradient_accumulate_every
-
-    for context in contexts:
-        with context():
-            yield
-
-
-def hinge_loss(real, fake):
-    return (F.relu(1 + real) + F.relu(1 - fake)).mean()
-
-
-def evaluate_in_chunks(max_batch_size, model, *args):
-    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
-    chunked_outputs = [model(*i) for i in split_args]
-    if len(chunked_outputs) == 1:
-        return chunked_outputs[0]
-    return torch.cat(chunked_outputs, dim=0)
-
-
-def slerp(val, low, high):
-    low_norm = low / torch.norm(low, dim=1, keepdim=True)
-    high_norm = high / torch.norm(high, dim=1, keepdim=True)
-    omega = torch.acos((low_norm * high_norm).sum(1))
-    so = torch.sin(omega)
-    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
-    return res
-
-
-def safe_div(n, d):
-    try:
-        res = n / d
-    except ZeroDivisionError:
-        prefix = '' if int(n >= 0) else '-'
-        res = float(f'{prefix}inf')
-    return res
-
-
-# helper classes
-
-class NanException(Exception):
-    pass
-
-
-class EMA():
-    def __init__(self, beta):
-        super().__init__()
-        self.beta = beta
-
-    def update_average(self, old, new):
-        if not exists(old):
-            return new
-        return old * self.beta + (1 - self.beta) * new
-
-
-class EMAWrapper(nn.Module):
-    def __init__(self, wrapped_module, following_module, rate=.995, steps_per_ema=10, steps_per_reset=1000, steps_after_no_reset=25000, reset=True):
-        super().__init__()
-        self.wrapped = wrapped_module
-        self.following = following_module
-        self.ema_updater = EMA(rate)
-        self.steps_per_ema = steps_per_ema
-        self.steps_per_reset = steps_per_reset
-        self.steps_after_no_reset = steps_after_no_reset
-        if reset:
-            self.wrapped.load_state_dict(self.following.state_dict())
-        for p in self.wrapped.parameters():
-            p.DO_NOT_TRAIN = True
-
-    def reset_parameter_averaging(self):
-        self.wrapped.load_state_dict(self.following.state_dict())
-
-    def update_moving_average(self):
-        for current_params, ma_params in zip(self.following.parameters(), self.wrapped.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
-
-        for current_buffer, ma_buffer in zip(self.following.buffers(), self.wrapped.buffers()):
-            new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
-            ma_buffer.copy_(new_buffer_value)
-
-    def after_step(self, step):
-        if step % self.steps_per_ema == 0:
-            self.update_moving_average()
-        if step % self.steps_per_reset and step < self.steps_after_no_reset:
-            self.reset_parameter_averaging()
-
-    def forward(self, x):
-        with torch.no_grad():
-            return self.wrapped(x)
-
-
-class RandomApply(nn.Module):
-    def __init__(self, prob, fn, fn_else=lambda x: x):
-        super().__init__()
-        self.fn = fn
-        self.fn_else = fn_else
-        self.prob = prob
-
-    def forward(self, x):
-        fn = self.fn if random() < self.prob else self.fn_else
-        return fn(x)
-
-
-class Rezero(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-        self.g = nn.Parameter(torch.tensor(1e-3))
-
-    def forward(self, x):
-        return self.g * self.fn(x)
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x):
-        return self.fn(x) + x
-
-
-class SumBranches(nn.Module):
-    def __init__(self, branches):
-        super().__init__()
-        self.branches = nn.ModuleList(branches)
-
-    def forward(self, x):
-        return sum(map(lambda fn: fn(x), self.branches))
-
-
-class Blur(nn.Module):
-    def __init__(self):
-        super().__init__()
-        f = torch.Tensor([1, 2, 1])
-        self.register_buffer('f', f)
-
-    def forward(self, x):
-        f = self.f
-        f = f[None, None, :] * f[None, :, None]
-        return filter2d(x, f, normalized=True)
-
-
-# dataset
-
-def convert_image_to(img_type, image):
-    if image.mode != img_type:
-        return image.convert(img_type)
-    return image
-
-
-class identity(object):
-    def __call__(self, tensor):
-        return tensor
-
-
-class expand_greyscale(object):
-    def __init__(self, transparent):
-        self.transparent = transparent
-
-    def __call__(self, tensor):
-        channels = tensor.shape[0]
-        num_target_channels = 4 if self.transparent else 3
-
-        if channels == num_target_channels:
-            return tensor
-
-        alpha = None
-        if channels == 1:
-            color = tensor.expand(3, -1, -1)
-        elif channels == 2:
-            color = tensor[:1].expand(3, -1, -1)
-            alpha = tensor[1:]
-        else:
-            raise Exception(f'image with invalid number of channels given {channels}')
-
-        if not exists(alpha) and self.transparent:
-            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
-
-        return color if not self.transparent else torch.cat((color, alpha))
-
-
-def resize_to_minimum_size(min_size, image):
-    if max(*image.size) < min_size:
-        return torchvision.transforms.functional.resize(image, min_size)
-    return image
-
-
-class ImageDataset(Dataset):
-    def __init__(
-        self,
-        folder,
-        image_size,
-        transparent=False,
-        greyscale=False,
-        aug_prob=0.
-    ):
-        super().__init__()
-        self.folder = folder
-        self.image_size = image_size
-        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
-        assert len(self.paths) > 0, f'No images were found in {folder} for training'
-
-        if transparent:
-            num_channels = 4
-            pillow_mode = 'RGBA'
-            expand_fn = expand_greyscale(transparent)
-        elif greyscale:
-            num_channels = 1
-            pillow_mode = 'L'
-            expand_fn = identity()
-        else:
-            num_channels = 3
-            pillow_mode = 'RGB'
-            expand_fn = expand_greyscale(transparent)
-
-        convert_image_fn = partial(convert_image_to, pillow_mode)
-
-        self.transform = transforms.Compose([
-            transforms.Lambda(convert_image_fn),
-            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
-            transforms.Resize(image_size),
-            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
-                        transforms.CenterCrop(image_size)),
-            transforms.ToTensor(),
-            transforms.Lambda(expand_fn)
-        ])
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, index):
-        path = self.paths[index]
-        img = Image.open(path)
-        return self.transform(img)
-
-
-# augmentations
-
-def random_hflip(tensor, prob):
-    if prob > random():
-        return tensor
-    return torch.flip(tensor, dims=(3,))
-
-
-class AugWrapper(nn.Module):
-    def __init__(self, D, image_size, prob, types):
-        super().__init__()
-        self.D = D
-        self.prob = prob
-        self.types = types
-
-    def forward(self, images, detach=False, **kwargs):
-        context = torch.no_grad if detach else null_context
-
-        with context():
-            if random() < self.prob:
-                images = random_hflip(images, prob=0.5)
-                images = DiffAugment(images, types=self.types)
-
-        return self.D(images, **kwargs)
-
-
-# modifiable global variables
-
-norm_class = nn.BatchNorm2d
-
-
-def upsample(scale_factor=2):
-    return nn.Upsample(scale_factor=scale_factor)
-
-
-# squeeze excitation classes
-
-# global context network
-# https://arxiv.org/abs/2012.13375
-# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm
-
-class GlobalContext(nn.Module):
-    def __init__(
-        self,
-        *,
-        chan_in,
-        chan_out
-    ):
-        super().__init__()
-        self.to_k = nn.Conv2d(chan_in, 1, 1)
-        chan_intermediate = max(3, chan_out // 2)
-
-        self.net = nn.Sequential(
-            nn.Conv2d(chan_in, chan_intermediate, 1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(chan_intermediate, chan_out, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        context = self.to_k(x)
-        context = context.flatten(2).softmax(dim=-1)
-        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
-        out = out.unsqueeze(-1)
-        return self.net(out)
-
-
-# frequency channel attention
-# https://arxiv.org/abs/2012.11879
-
-def get_1d_dct(i, freq, L):
-    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
-    return result * (1 if freq == 0 else math.sqrt(2))
-
-
-def get_dct_weights(width, channel, fidx_u, fidx_v):
-    dct_weights = torch.zeros(1, channel, width, width)
-    c_part = channel // len(fidx_u)
-
-    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
-        for x in range(width):
-            for y in range(width):
-                coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
-                dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value
-
-    return dct_weights
-
-
-class FCANet(nn.Module):
-    def __init__(
-        self,
-        *,
-        chan_in,
-        chan_out,
-        reduction=4,
-        width
-    ):
-        super().__init__()
-
-        freq_w, freq_h = ([0] * 8), list(range(8))  # in paper, it seems 16 frequencies was ideal
-        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
-        self.register_buffer('dct_weights', dct_weights)
-
-        chan_intermediate = max(3, chan_out // reduction)
-
-        self.net = nn.Sequential(
-            nn.Conv2d(chan_in, chan_intermediate, 1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(chan_intermediate, chan_out, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1=1, w1=1)
-        return self.net(x)
-
-
-# generative adversarial network
-
-class Generator(nn.Module):
-    def __init__(
-        self,
-        *,
-        image_size,
-        latent_dim=256,
-        fmap_max=512,
-        fmap_inverse_coef=12,
-        transparent=False,
-        greyscale=False,
-        freq_chan_attn=False
-    ):
-        super().__init__()
-        resolution = log2(image_size)
-        assert is_power_of_two(image_size), 'image size must be a power of 2'
-
-        if transparent:
-            init_channel = 4
-        elif greyscale:
-            init_channel = 1
-        else:
-            init_channel = 3
-
-        fmap_max = default(fmap_max, latent_dim)
-
-        self.initial_conv = nn.Sequential(
-            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
-            norm_class(latent_dim * 2),
-            nn.GLU(dim=1)
-        )
-
-        num_layers = int(resolution) - 2
-        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
-        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
-        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
-        features = [latent_dim, *features]
-
-        in_out_features = list(zip(features[:-1], features[1:]))
-
-        self.res_layers = range(2, num_layers + 2)
-        self.layers = nn.ModuleList([])
-        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
-
-        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
-        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
-        self.sle_map = dict(self.sle_map)
-
-        self.num_layers_spatial_res = 1
-
-        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
-            attn = None
-            sle = None
-            if res in self.sle_map:
-                residual_layer = self.sle_map[res]
-                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
-
-                if freq_chan_attn:
-                    sle = FCANet(
-                        chan_in=chan_out,
-                        chan_out=sle_chan_out,
-                        width=2 ** (res + 1)
-                    )
-                else:
-                    sle = GlobalContext(
-                        chan_in=chan_out,
-                        chan_out=sle_chan_out
-                    )
-
-            layer = nn.ModuleList([
-                nn.Sequential(
-                    upsample(),
-                    Blur(),
-                    nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
-                    norm_class(chan_out * 2),
-                    nn.GLU(dim=1)
-                ),
-                sle,
-                attn
-            ])
-            self.layers.append(layer)
-
-        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
-
-        for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
-                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
-
-    def forward(self, x):
-        x = rearrange(x, 'b c -> b c () ()')
-        x = self.initial_conv(x)
-        x = F.normalize(x, dim=1)
-
-        residuals = dict()
-
-        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
-            if exists(attn):
-                x = attn(x) + x
-
-            x = up(x)
-
-            if exists(sle):
-                out_res = self.sle_map[res]
-                residual = sle(x)
-                residuals[out_res] = residual
-
-            next_res = res + 1
-            if next_res in residuals:
-                x = x * residuals[next_res]
-
-        return self.out_conv(x)
-
-
-class SimpleDecoder(nn.Module):
-    def __init__(
-        self,
-        *,
-        chan_in,
-        chan_out=3,
-        num_upsamples=4,
-    ):
-        super().__init__()
-
-        self.layers = nn.ModuleList([])
-        final_chan = chan_out
-        chans = chan_in
-
-        for ind in range(num_upsamples):
-            last_layer = ind == (num_upsamples - 1)
-            chan_out = chans if not last_layer else final_chan * 2
-            layer = nn.Sequential(
-                upsample(),
-                nn.Conv2d(chans, chan_out, 3, padding=1),
-                nn.GLU(dim=1)
-            )
-            self.layers.append(layer)
-            chans //= 2
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
-class Discriminator(nn.Module):
-    def __init__(
-        self,
-        *,
-        image_size,
-        fmap_max=512,
-        fmap_inverse_coef=12,
-        transparent=False,
-        greyscale=False,
-        disc_output_size=5,
-        attn_res_layers=[]
-    ):
-        super().__init__()
-        self.image_size = image_size
-        resolution = log2(image_size)
-        assert is_power_of_two(image_size), 'image size must be a power of 2'
-        assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'
-
-        resolution = int(resolution)
-
-        if transparent:
-            init_channel = 4
-        elif greyscale:
-            init_channel = 1
-        else:
-            init_channel = 3
-
-        num_non_residual_layers = max(0, int(resolution) - 8)
-        num_residual_layers = 8 - 3
-
-        non_residual_resolutions = range(min(8, resolution), 2, -1)
-        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
-        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
-
-        if num_non_residual_layers == 0:
-            res, _ = features[0]
-            features[0] = (res, init_channel)
-
-        chan_in_out = list(zip(features[:-1], features[1:]))
-
-        self.non_residual_layers = nn.ModuleList([])
-        for ind in range(num_non_residual_layers):
-            first_layer = ind == 0
-            last_layer = ind == (num_non_residual_layers - 1)
-            chan_out = features[0][-1] if last_layer else init_channel
-
-            self.non_residual_layers.append(nn.Sequential(
-                Blur(),
-                nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
-                nn.LeakyReLU(0.1)
-            ))
-
-        self.residual_layers = nn.ModuleList([])
-
-        for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
-            attn = None
-            self.residual_layers.append(nn.ModuleList([
-                SumBranches([
-                    nn.Sequential(
-                        Blur(),
-                        nn.Conv2d(chan_in, chan_out, 4, stride=2, padding=1),
-                        nn.LeakyReLU(0.1),
-                        nn.Conv2d(chan_out, chan_out, 3, padding=1),
-                        nn.LeakyReLU(0.1)
-                    ),
-                    nn.Sequential(
-                        Blur(),
-                        nn.AvgPool2d(2),
-                        nn.Conv2d(chan_in, chan_out, 1),
-                        nn.LeakyReLU(0.1),
-                    )
-                ]),
-                attn
-            ]))
-
-        last_chan = features[-1][-1]
-        if disc_output_size == 5:
-            self.to_logits = nn.Sequential(
-                nn.Conv2d(last_chan, last_chan, 1),
-                nn.LeakyReLU(0.1),
-                nn.Conv2d(last_chan, 1, 4)
-            )
-        elif disc_output_size == 1:
-            self.to_logits = nn.Sequential(
-                Blur(),
-                nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
-                nn.LeakyReLU(0.1),
-                nn.Conv2d(last_chan, 1, 4)
-            )
-
-        self.to_shape_disc_out = nn.Sequential(
-            nn.Conv2d(init_channel, 64, 3, padding=1),
-            Residual(Rezero(GSA(dim=64, norm_queries=True, batch_norm=False))),
-            SumBranches([
-                nn.Sequential(
-                    Blur(),
-                    nn.Conv2d(64, 32, 4, stride=2, padding=1),
-                    nn.LeakyReLU(0.1),
-                    nn.Conv2d(32, 32, 3, padding=1),
-                    nn.LeakyReLU(0.1)
-                ),
-                nn.Sequential(
-                    Blur(),
-                    nn.AvgPool2d(2),
-                    nn.Conv2d(64, 32, 1),
-                    nn.LeakyReLU(0.1),
-                )
-            ]),
-            Residual(Rezero(GSA(dim=32, norm_queries=True, batch_norm=False))),
-            nn.AdaptiveAvgPool2d((4, 4)),
-            nn.Conv2d(32, 1, 4)
-        )
-
-        self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
-        self.decoder2 = SimpleDecoder(chan_in=features[-2][-1], chan_out=init_channel) if resolution >= 9 else None
-
-        for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
-                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
-
-    def forward(self, x, calc_aux_loss=False):
-        orig_img = x
-
-        for layer in self.non_residual_layers:
-            x = layer(x)
-
-        layer_outputs = []
-
-        for (net, attn) in self.residual_layers:
-            if exists(attn):
-                x = attn(x) + x
-
-            x = net(x)
-            layer_outputs.append(x)
-
-        out = self.to_logits(x).flatten(1)
-
-        img_32x32 = F.interpolate(orig_img, size=(32, 32))
-        out_32x32 = self.to_shape_disc_out(img_32x32)
-
-        if not calc_aux_loss:
-            return out, out_32x32, None
-
-        # self-supervised auto-encoding loss
-
-        layer_8x8 = layer_outputs[-1]
-        layer_16x16 = layer_outputs[-2]
-
-        recon_img_8x8 = self.decoder1(layer_8x8)
-
-        aux_loss = F.mse_loss(
-            recon_img_8x8,
-            F.interpolate(orig_img, size=recon_img_8x8.shape[2:])
-        )
-
-        if exists(self.decoder2):
-            select_random_quadrant = lambda rand_quadrant, img: \
-                rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m=2, n=2)[rand_quadrant]
-            crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
-            img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))
-
-            recon_img_16x16 = self.decoder2(layer_16x16_part)
-
-            aux_loss_16x16 = F.mse_loss(
-                recon_img_16x16,
-                F.interpolate(img_part, size=recon_img_16x16.shape[2:])
-            )
-
-            aux_loss = aux_loss + aux_loss_16x16
-
-        return out, out_32x32, aux_loss
-
-
-class LightweightGanDivergenceLoss(L.ConfigurableLoss):
-    def __init__(self, opt, env):
-        super().__init__(opt, env)
-        self.real = opt['real']
-        self.fake = opt['fake']
-        self.discriminator = opt['discriminator']
-        self.for_gen = opt['gen_loss']
-        self.gp_frequency = opt['gradient_penalty_frequency']
-        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
-        # TODO: Implement generator top-k fractional loss compensation.
-
-    def forward(self, net, state):
-        real_input = state[self.real]
-        fake_input = state[self.fake]
-        if self.noise != 0:
-            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
-            real_input = real_input + torch.rand_like(real_input) * self.noise
-
-        D = self.env['discriminators'][self.discriminator]
-        fake, fake32, _ = D(fake_input, detach=not self.for_gen)
-        if self.for_gen:
-            return fake.mean() + fake32.mean()
-        else:
-            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
-            real, real32, real_aux = D(real_input, calc_aux_loss=True)
-            divergence_loss = hinge_loss(real, fake) + hinge_loss(real32, fake32) + real_aux
-
-            # Apply gradient penalty. TODO: migrate this elsewhere.
-            if self.env['step'] % self.gp_frequency == 0:
-                gp = gradient_penalty(real_input, real)
-                self.metrics.append(("gradient_penalty", gp.clone().detach()))
-                divergence_loss = divergence_loss + gp
-
-            real_input.requires_grad_(requires_grad=False)
-            return divergence_loss
-
-
-@register_model
-def register_lightweight_gan_g(opt_net, opt, other_nets):
-    gen = Generator(**opt_net['kwargs'])
-    if opt_get(opt_net, ['ema'], False):
-        following = other_nets[opt_net['following']]
-        return EMAWrapper(gen, following, opt_net['rate'])
-    return gen
-
-
-@register_model
-def register_lightweight_gan_d(opt_net, opt):
-    d = Discriminator(**opt_net['kwargs'])
-    if opt_net['aug']:
-        return AugWrapper(d, d.image_size, opt_net['aug_prob'], opt_net['aug_types'])
-    return d
-
-
-if __name__ == '__main__':
-    g = Generator(image_size=256)
-    d = Discriminator(image_size=256)
-    j = torch.randn(1,256)
-    r = g(j)
-    a, b, c = d(r)
-    print(a.shape)
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index 2a081c47..72a5e188 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -18,12 +18,6 @@ def create_loss(opt_loss, env):
     elif 'stylegan2_' in type:
         from models.image_generation.stylegan import create_stylegan2_loss
         return create_stylegan2_loss(opt_loss, env)
-    elif 'style_sr_' in type:
-        from models.styled_sr import create_stylesr_loss
-        return create_stylesr_loss(opt_loss, env)
-    elif 'lightweight_gan_divergence' == type:
-        from models.image_generation.lightweight_gan import LightweightGanDivergenceLoss
-        return LightweightGanDivergenceLoss(opt_loss, env)
     elif type == 'crossentropy' or type == 'cross_entropy':
         return CrossEntropy(opt_loss, env)
     elif type == 'distillation':