# Heavily based on the lucidrains stylegan2 discriminator implementation.
import math
import os
from functools import partial
from math import log2
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad as torch_grad
import trainer.losses as L
from vector_quantize_pytorch import VectorQuantize

from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists
from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


class DiscriminatorBlock(nn.Module):
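    """Residual discriminator block: two 3x3 convs on the main path and a 1x1
    (strided, when downsampling) conv on the skip path. When `downsample` is set,
    the main path ends with a blur + strided conv. The two paths are summed and
    scaled by 1/sqrt(2) to keep activation variance roughly constant."""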
    def __init__(self, input_channels, filters, downsample=True, transfer_mode=False):
        super().__init__()
        self.filters = filters
        self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode)

        self.net = nn.Sequential(
            TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode),
            leaky_relu(),
            TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class StyleSrDiscriminator(nn.Module):
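    """StyleGAN2-style discriminator. A stack of residual blocks downsamples the
    input to a 2x2 feature map (with optional attention and vector-quantization
    layers in between), followed by a final 3x3 conv and a linear or small-MLP
    head that produces one logit per image."""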
    def __init__(self, image_size, network_capacity=16, fq_layers=None, fq_dict_size=256, attn_layers=None,
                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False,
                 transfer_mode=False):
        super().__init__()
        fq_layers = fq_layers if fq_layers is not None else []
        attn_layers = attn_layers if attn_layers is not None else []
        num_layers = int(log2(image_size) - 1)

        filters = [input_filters] + [network_capacity * 4 * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        attn_blocks = []
        quantize_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode)
            blocks.append(block)

            attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None

            attn_blocks.append(attn_fn)

            if quantize:
                quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
                quantize_blocks.append(quantize_fn)
            else:
                quantize_blocks.append(None)

        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        self.quantize_blocks = nn.ModuleList(quantize_blocks)
        self.do_checkpointing = do_checkpointing

        chan_last = filters[-1]
        latent_dim = 2 * 2 * chan_last  # the blocks downsample the input to a 2x2 map

        self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode)
        self.flatten = nn.Flatten()
        if mlp:
            self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
                                          leaky_relu(),
                                          TransferLinear(100, 1, transfer_mode=transfer_mode))
        else:
            self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode)

        self._init_weights()

        self.transfer_mode = transfer_mode
        if transfer_mode:
            for p in self.parameters():
                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
                    p.DO_NOT_TRAIN = True

    def forward(self, x):
        quantize_loss = torch.zeros(1).to(x)

        for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
            if self.do_checkpointing:
                x = checkpoint(block, x)
            else:
                x = block(x)

            if exists(attn_block):
                x = attn_block(x)

            if exists(q_block):
                x, _, loss = q_block(x)
                quantize_loss += loss

        x = self.final_conv(x)
        x = self.flatten(x)
        x = self.to_logit(x)
        # `q_block` only reflects the last loop iteration; check every quantize
        # block so the quantization loss is returned whenever any layer uses it.
        if any(map(exists, self.quantize_blocks)):
            return x.squeeze(), quantize_loss
        else:
            return x.squeeze()

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {TransferConv2d, TransferLinear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    # Configures the network as partially pre-trained. This means:
    # 1) The top (highest-resolution) `num_blocks` blocks will have their weights re-initialized.
    # 2) The head (linear layers) will also have its weights re-initialized.
    # 3) All remaining blocks will be frozen until step `frozen_until_step`.
    # These settings are applied after the weights have been loaded, in network_loaded().
    def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
        self.bypass_blocks = bypass_blocks
        self.num_blocks = num_blocks
        self.frozen_until_step = frozen_until_step
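        # A hypothetical usage sketch (values are illustrative only): after loading
        # pretrained weights,
        #   disc.configure_partial_training(bypass_blocks=1, num_blocks=2, frozen_until_step=20000)
        #   disc.network_loaded()
        # drops the highest-resolution block, re-initializes the two blocks now at
        # the top plus the head, and freezes everything else until step 20000.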

    # Called after the network weights are loaded.
    def network_loaded(self):
        if not hasattr(self, 'frozen_until_step'):
            return

        if self.bypass_blocks > 0:
            self.blocks = self.blocks[self.bypass_blocks:]
            self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)

        reset_blocks = [self.to_logit] + list(self.blocks[:self.num_blocks])
        for bl in reset_blocks:
            for m in bl.modules():
                if type(m) in {TransferConv2d, TransferLinear}:
                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
            # Mark every parameter of the re-initialized blocks so the freeze pass
            # below skips them.
            for p in bl.parameters():
                p._NEW_BLOCK = True
        for p in self.parameters():
            if not hasattr(p, '_NEW_BLOCK'):
                p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step


# Helper functions: DiffAugment-style differentiable augmentations, applied
# identically to real and fake batches so the discriminator cannot exploit them.
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()
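
# For example, DiffAugment(x, types=['color', 'translation']) applies brightness,
# saturation and contrast jitter followed by a random integer shift to each image
# in the batch (see AUGMENT_FNS below for the mapping).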


def random_hflip(tensor, prob):
    # Flip horizontally with probability `prob`.
    if random() < prob:
        return torch.flip(tensor, dims=(3,))
    return tensor


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}


class DiscAugmentor(nn.Module):
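    """Wraps a discriminator and, with probability `prob`, applies a random
    horizontal flip plus DiffAugment to its inputs before forwarding them to D.
    The augmented batches are stashed for visualization and downstream losses."""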
    def __init__(self, D, image_size, types, prob):
        super().__init__()
        self.D = D
        self.prob = prob
        self.types = types

    def forward(self, images, real_images=False):
        if random() < self.prob:
            images = random_hflip(images, prob=0.5)
            images = DiffAugment(images, types=self.types)

        if real_images:
            self.hq_aug = images.detach().clone()
        else:
            self.gen_aug = images.detach().clone()

        # Save away for use elsewhere (e.g. unet loss)
        self.aug_images = images

        return self.D(images)

    def network_loaded(self):
        self.D.network_loaded()

    # Allows visualizing what the augmentor is up to.
    def visual_dbg(self, step, path):
        torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step)))
        torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step)))


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    if fp16:
        # fp16 training relies on NVIDIA apex for loss scaling; imported lazily so
        # the dependency is only required when fp16 is actually enabled.
        from apex import amp
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def gradient_penalty(images, output, weight=10, return_structured_grads=False):
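    """WGAN-GP-style penalty: encourages the L2 norm of the discriminator's
    gradient with respect to its input images to stay near 1."""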
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones(output.size(), device=images.device),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]

    flat_grad = gradients.reshape(batch_size, -1)
    penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
    if return_structured_grads:
        return penalty, gradients
    else:
        return penalty


class StyleSrGanDivergenceLoss(L.ConfigurableLoss):
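    """Hinge-style GAN divergence (lucidrains sign convention) with optional input
    noise and a gradient penalty applied to the real batch every
    `gradient_penalty_frequency` steps."""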
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt_get(opt, ['noise'], 0)

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake = D(fake_input, real_images=False)
        if self.for_gen:
            return fake.mean()
        else:
            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
            real = D(real_input, real_images=True)
            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

            # Apply gradient penalty. TODO: migrate this elsewhere.
            if self.env['step'] % self.gp_frequency == 0:
                gp = gradient_penalty(real_input, real)
                self.metrics.append(("gradient_penalty", gp.clone().detach()))
                divergence_loss = divergence_loss + gp

            real_input.requires_grad_(requires_grad=False)
            return divergence_loss


@register_model
def register_styledsr_discriminator(opt_net, opt):
    attn = opt_get(opt_net, ['attn_layers'], [])
    disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
                                do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
                                quantize=opt_get(opt_net, ['quantize'], False),
                                mlp=opt_get(opt_net, ['mlp_head'], True),
                                transfer_mode=opt_get(opt_net, ['transfer_mode'], False)
                                )
    if 'use_partial_pretrained' in opt_net:
        disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
    return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
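
# A minimal sketch of the config keys this registration consumes; the values below
# are illustrative, not a recommended setup:
#   opt_net = {'image_size': 128, 'in_nc': 3,
#              'augmentation_types': ['color', 'translation'],
#              'augmentation_probability': 0.5}
#   disc = register_styledsr_discriminator(opt_net, opt={})  # normally invoked via the model registry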