import torch
import torch.nn as nn
from models.networks import define_F
from models.loss import GANLoss
from torchvision.utils import save_image


def create_generator_loss(opt_loss, env):
    """Build the configurable loss selected by ``opt_loss['type']``.

    Args:
        opt_loss: Loss options mapping. Its ``'type'`` entry selects the loss
            class; the whole mapping is forwarded to that class's constructor.
        env: Environment object, forwarded unchanged to the constructor.

    Returns:
        A new ``PixLoss``, ``FeatureLoss``, ``GeneratorGanLoss`` or
        ``DiscriminatorGanLoss`` instance.

    Raises:
        NotImplementedError: If ``opt_loss['type']`` is not one of ``'pix'``,
            ``'feature'``, ``'generator_gan'`` or ``'discriminator_gan'``.
    """
    # Named `loss_type` rather than `type` so the builtin is not shadowed.
    loss_type = opt_loss['type']
    if loss_type == 'pix':
        return PixLoss(opt_loss, env)
    if loss_type == 'feature':
        return FeatureLoss(opt_loss, env)
    if loss_type == 'generator_gan':
        return GeneratorGanLoss(opt_loss, env)
    if loss_type == 'discriminator_gan':
        return DiscriminatorGanLoss(opt_loss, env)
    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError('Unsupported loss type: {!r}'.format(loss_type))


class ConfigurableLoss(nn.Module):
    """Common base for the option-driven loss modules in this file.

    Holds the loss options (``opt``), the shared environment (``env``) and a
    ``metrics`` list that subclasses may fill during ``forward``.
    """

    def __init__(self, opt, env):
        """Store ``opt`` and ``env`` and start with an empty metrics list."""
        super().__init__()
        self.opt, self.env = opt, env
        self.metrics = []

    def forward(self, net, state):
        """Compute the loss; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def extra_metrics(self):
        """Return the current ``metrics`` list (the same object, not a copy)."""
        return self.metrics


def get_basic_criterion_for_name(name, device):
    """Return the elementwise criterion module registered under ``name``.

    Args:
        name: ``'l1'`` for ``nn.L1Loss`` or ``'l2'`` for ``nn.MSELoss``.
        device: Target passed to the module's ``.to()`` call.

    Returns:
        The criterion module after ``.to(device)``.

    Raises:
        NotImplementedError: If ``name`` is neither ``'l1'`` nor ``'l2'``.
    """
    if name == 'l1':
        criterion = nn.L1Loss()
    elif name == 'l2':
        criterion = nn.MSELoss()
    else:
        # Name the rejected value so a config typo is easy to spot.
        raise NotImplementedError(
            'Unsupported criterion name: {!r}'.format(name))
    return criterion.to(device)


class PixLoss(ConfigurableLoss):
    """Loss that applies a basic criterion directly to two ``state`` entries.

    Options read from ``opt``: ``'criterion'`` (criterion name), ``'fake'`` and
    ``'real'`` (keys into ``state``).
    """

    def __init__(self, opt, env):
        """Create the criterion named by ``opt['criterion']`` on ``env['device']``."""
        # The base constructor already stores `opt` and `env` on the instance.
        super().__init__(opt, env)
        criterion_name = opt['criterion']
        self.criterion = get_basic_criterion_for_name(criterion_name, env['device'])

    def forward(self, net, state):
        """Return ``criterion(fake, real)``; ``net`` is not used here."""
        generated = state[self.opt['fake']]
        target = state[self.opt['real']]
        return self.criterion(generated, target)


class FeatureLoss(ConfigurableLoss):
    """Loss that compares two ``state`` entries after passing both through ``netF``.

    Options read from ``opt``: ``'criterion'``, ``'which_model_F'``, ``'fake'``
    and ``'real'``.
    """

    def __init__(self, opt, env):
        """Build the criterion and the ``netF`` network from ``define_F``.

        ``netF`` is wrapped in ``DataParallel`` unless ``env['opt']['dist']``
        is truthy.
        """
        super().__init__(opt, env)
        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
        extractor = define_F(which_model=opt['which_model_F'])
        extractor = extractor.to(env['device'])
        if env['opt']['dist']:
            self.netF = extractor
        else:
            self.netF = torch.nn.parallel.DataParallel(extractor)

    def forward(self, net, state):
        """Return ``criterion(netF(fake), netF(real))``; ``net`` is not used here."""
        # The real-side output is only a target, so no gradient is tracked for it.
        with torch.no_grad():
            target_out = self.netF(state[self.opt['real']])
        generated_out = self.netF(state[self.opt['fake']])
        return self.criterion(generated_out, target_out)


class GeneratorGanLoss(ConfigurableLoss):
    """Adversarial loss for the generator side.

    Options read from ``opt``: ``'gan_type'``, ``'discriminator'`` (key into
    ``env['discriminators']``), ``'fake'`` and, for ``'ragan'``, ``'real'``
    (keys into ``state``).
    """

    def __init__(self, opt, env):
        """Create the ``GANLoss`` criterion for ``opt['gan_type']``."""
        super(GeneratorGanLoss, self).__init__(opt, env)
        self.opt = opt
        # NOTE(review): 1.0 / 0.0 are presumably the real / fake label values
        # expected by GANLoss -- confirm against models.loss.GANLoss.
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])

    def forward(self, net, state):
        """Return the generator's adversarial loss for the current ``state``.

        Args:
            net: Not used; the discriminator is looked up in
                ``env['discriminators']`` instead.
            state: Mapping holding the tensors named by ``opt['fake']`` /
                ``opt['real']`` (and ``'lq_fullsize_ref'`` for ``'crossgan'``).

        Raises:
            NotImplementedError: If ``opt['gan_type']`` is not handled here.
        """
        netD = self.env['discriminators'][self.opt['discriminator']]
        gan_type = self.opt['gan_type']
        if gan_type in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
            if gan_type == 'crossgan':
                # crossgan discriminators take the reference as a second input.
                pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
            else:
                pred_g_fake = netD(state[self.opt['fake']])
            return self.criterion(pred_g_fake, True)
        elif gan_type == 'ragan':
            # The real prediction is detached so it contributes no gradient.
            pred_d_real = netD(state[self.opt['real']]).detach()
            pred_g_fake = netD(state[self.opt['fake']])
            # Fixed: previously called the undefined `self.cri_gan`, which made
            # this branch raise AttributeError; the loss lives in `self.criterion`.
            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
                    self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
        else:
            raise NotImplementedError


class DiscriminatorGanLoss(ConfigurableLoss):
    """Adversarial loss for the discriminator side.

    Options read from ``opt``: ``'gan_type'``, ``'real'`` and ``'fake'`` (keys
    into ``state``). Each ``forward`` call rebuilds ``self.metrics``.
    """

    def __init__(self, opt, env):
        """Create the ``GANLoss`` criterion for ``opt['gan_type']``."""
        super(DiscriminatorGanLoss, self).__init__(opt, env)
        self.opt = opt
        # NOTE(review): 1.0 / 0.0 are presumably the real / fake label values
        # expected by GANLoss -- confirm against models.loss.GANLoss.
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])

    def forward(self, net, state):
        """Return the discriminator loss for the current ``state``.

        Args:
            net: The discriminator being trained; called on the real and the
                detached fake tensors.
            state: Mapping holding the tensors named by ``opt['real']`` and
                ``opt['fake']`` (and ``'lq_fullsize_ref'`` for ``'crossgan'``).

        Side effects:
            Resets ``self.metrics`` and appends ``("d_fake", mean(d_fake))``;
            the ``gan`` / ``pixgan`` / ``crossgan`` branch also appends
            ``"l_fake"`` and, for ``crossgan``, ``"l_mismatch"``.

        Raises:
            NotImplementedError: If ``opt['gan_type']`` is not handled here.
        """
        self.metrics = []

        # The fake input is detached so this loss does not backpropagate into
        # whatever produced it.
        if self.opt['gan_type'] == 'crossgan':
            d_real = net(state[self.opt['real']], state['lq_fullsize_ref'])
            d_fake = net(state[self.opt['fake']].detach(), state['lq_fullsize_ref'])
            # Shift the references by one along dim 0 so each image is paired
            # with another sample's reference.
            mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
            d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
            d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
        else:
            d_real = net(state[self.opt['real']])
            d_fake = net(state[self.opt['fake']].detach())
        self.metrics.append(("d_fake", torch.mean(d_fake)))

        if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan']:
            l_real = self.criterion(d_real, True)
            l_fake = self.criterion(d_fake, False)
            l_total = l_real + l_fake
            if self.opt['gan_type'] == 'crossgan':
                # Both mismatched pairings are scored against the False target.
                l_mreal = self.criterion(d_mismatch_real, False)
                l_mfake = self.criterion(d_mismatch_fake, False)
                l_total += l_mreal + l_mfake
                self.metrics.append(("l_mismatch", l_mfake + l_mreal))
            self.metrics.append(("l_fake", l_fake))
            return l_total
        elif self.opt['gan_type'] == 'ragan':
            # Fixed: previously called the undefined `self.cri_gan`, which made
            # this branch raise AttributeError; the loss lives in `self.criterion`.
            return (self.criterion(d_real - torch.mean(d_fake), True) +
                    self.criterion(d_fake - torch.mean(d_real), False))
        else:
            raise NotImplementedError