from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.arch_util import kaiming_init
from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock
from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear
from trainer.networks import register_model
from utils.util import checkpoint, opt_get


def rrdb_init_weights(module, scale=1):
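    # Kaiming-init each conv/linear weight, then damp it by `scale`; RRDB residual
    # branches are conventionally initialized small (e.g. 0.1) so each block starts
    # close to an identity mapping.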
    for m in module.modules():
        if isinstance(m, TransferConv2d):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale
        elif isinstance(m, TransferLinear):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale


class EncoderRRDB(nn.Module):
    def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False):
        super().__init__()
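        # Five 3x3 convs with dense connections: conv{i+1} sees the block input
        # concatenated with every previous growth map, so its input width is
        # mid_channels + i * growth_channels.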
        for i in range(5):
            out_channels = output_channels if i == 4 else growth_channels
            self.add_module(
                f'conv{i+1}',
                TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        for i in range(5):
            rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5
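
# A minimal shape check for EncoderRRDB (illustrative, using the defaults above):
#
#   rrdb = EncoderRRDB(mid_channels=64, output_channels=32)
#   y = rrdb(torch.randn(1, 64, 16, 16))
#   assert y.shape == (1, 32, 16, 16)  # spatial size preserved; emits output_channels maps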


class StyledSrEncoder(nn.Module):
    def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False):
        super().__init__()
        # Note: fea_out is currently unused; the architecture below hard-codes 256 output channels.
        self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode)
        self.rrdbs = nn.ModuleList([
            EncoderRRDB(32, transfer_mode=transfer_mode),
            EncoderRRDB(64, transfer_mode=transfer_mode),
            EncoderRRDB(96, transfer_mode=transfer_mode),
            EncoderRRDB(128, transfer_mode=transfer_mode),
            EncoderRRDB(160, transfer_mode=transfer_mode),
            EncoderRRDB(192, transfer_mode=transfer_mode),
            EncoderRRDB(224, transfer_mode=transfer_mode)])

    def forward(self, x):
        fea = self.initial_conv(x)
        for rrdb in self.rrdbs:
            fea = torch.cat([fea, checkpoint(rrdb, fea)], dim=1)
        return fea
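
# Channel bookkeeping: the initial conv emits 32 feature maps and each of the 7
# RRDBs appends 32 more, so the encoder always outputs 32 + 7 * 32 = 256 channels.
# A rough check:
#
#   enc = StyledSrEncoder()
#   assert enc(torch.randn(1, 3, 64, 64)).shape[1] == 256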


class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False):
        super().__init__()
        total_levels = upsample_levels + 1  # The first level handles the raw encoder output and doesn't upsample.
        self.image_size = image_size
        self.scale = 2 ** upsample_levels
        self.latent_dim = latent_dim
        self.num_layers = total_levels
        self.transfer_mode = transfer_mode
        filters = [
            512,  # 4x4
            512,  # 8x8
            512,  # 16x16
            256,  # 32x32
            128,  # 64x64
            64,   # 128x128
            32,   # 256x256
            16,   # 512x512
            8,    # 1024x1024
        ]

        # I'm guessing that the encoder does not need transfer learning, hence transfer_mode
        # is fixed to False here. This should be vetted.
        self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False)

        in_out_pairs = list(zip(filters[:-1], filters[1:]))
        self.blocks = nn.ModuleList([])
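        # With the defaults (start_level=3, upsample_levels=2) this selects
        # in_out_pairs[3:6]: a 256->128->64->32 channel chain whose first block
        # consumes the 256-channel encoder output without upsampling.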
        for ind in range(start_level, start_level+total_levels):
            in_chan, out_chan = in_out_pairs[ind]
            not_first = ind != start_level
            not_last = ind != (start_level+total_levels-1)
            block = GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                transfer_learning_mode=transfer_mode
            )
            self.blocks.append(block)

    def forward(self, lr, styles):
        b, c, h, w = lr.shape
        if self.transfer_mode:
            with torch.no_grad():
                x = self.encoder(lr)
        else:
            x = self.encoder(lr)

        styles = styles.transpose(0, 1)
        input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
        if h != x.shape[-2]:
            rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
        else:
            rgb = lr

        for style, block in zip(styles, self.blocks):
            x, rgb = checkpoint(block, x, rgb, style, input_noise)

        return rgb
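
# Expected call shapes for Generator.forward (a sketch, assuming the defaults
# start_level=3 and upsample_levels=2, i.e. num_layers=3 and scale=4):
#
#   g = Generator(image_size=128, latent_dim=512)
#   lr = torch.randn(1, 3, 32, 32)
#   styles = torch.randn(1, g.num_layers, 512)  # batch-first per-layer latents
#   sr = g(lr, styles)                          # -> (1, 3, 128, 128)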


class StyledSrGenerator(nn.Module):
    def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False):
        super().__init__()
        # Assume the vectorizer doesn't need transfer_mode=True. Re-evaluate this later.
        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
        self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
        self.l2 = nn.MSELoss()
        self.mixed_prob = .9
        self._init_weights()
        self.transfer_mode = transfer_mode
        self.initial_stride = initial_stride
        if transfer_mode:
            for p in self.parameters():
                if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
                    p.DO_NOT_TRAIN = True


    def _init_weights(self):
        for m in self.modules():
            if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

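        # Zero the noise-injection projections so each block initially ignores its
        # per-pixel noise input, a common StyleGAN2 initialization choice.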
        for block in self.gen.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)

    def forward(self, x):
        b = x.shape[0]  # Only the batch size is needed here; 'w' below holds the style latents.

        # Synthesize style latents from noise.
        style = torch.randn(b*2, self.gen.latent_dim).to(x.device)
        if self.transfer_mode:
            with torch.no_grad():
                w = self.vectorizer(style)
        else:
            w = self.vectorizer(style)

        # Style mixing: with probability mixed_prob, splice two independent latents at a
        # random per-sample layer cutoff; otherwise reuse a single latent for every layer.
        w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
        for j in range(b):
            cutoff = int(torch.rand(()).item() * self.gen.num_layers)
            if cutoff == self.gen.num_layers or random() > self.mixed_prob:
                w_styles[j] = w_styles[j*2]
            else:
                w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
                w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
        w_styles = w_styles[:b]

        out = self.gen(x, w_styles)

        # L2 regularization: area-downsample the generated image to the encoder's working
        # resolution (LR size / initial_stride) and compare it against the correspondingly
        # downsampled input.
        out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
        if self.initial_stride > 1:
            x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
        l2_reg = self.l2(x, out_down)

        return out, l2_reg, w_styles


if __name__ == '__main__':
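    # Smoke test: with a 64x64 input and initial_stride=2, the net scale factor is
    # 2 ** upsample_levels / initial_stride = 2, so the SR output should be 128x128.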
    gen = StyledSrGenerator(128, initial_stride=2)
    out = gen(torch.rand(1, 3, 64, 64))
    print([o.shape for o in out])
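
# Illustrative options block for registering this model in a DLAS config (a
# sketch; only initial_stride and transfer_mode are read below, and the key
# names around them are assumptions):
#
#   generator:
#     which_model_G: styled_sr
#     initial_stride: 1
#     transfer_mode: false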


@register_model
def register_styled_sr(opt_net, opt):
    return StyledSrGenerator(128,
                             initial_stride=opt_get(opt_net, ['initial_stride'], 1),
                             transfer_mode=opt_get(opt_net, ['transfer_mode'], False))