from typing import Optional, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.utils import _ntuple
import torch.nn.functional as F


_pair = _ntuple(2)


# Indexes the <p> axis of input=b,c,h,w,p with the long tensor index=b,1,h,w. The result is b,c,h,w.
def index_2d(input, index):
    """Select one entry of the trailing ``p`` axis of ``input`` per location.

    Args:
        input: Tensor of shape (b, c, h, w, p).
        index: Integer tensor of shape (b, 1, h, w) with values in [0, p). Size-1 dimensions are
            broadcast against ``input`` (for example, a single index map shared by all channels).

    Returns:
        Tensor of shape (b, c, h, w) where
        ``result[b, c, h, w] == input[b, c, h, w, index[b, 0, h, w]]``. It has the same dtype as
        ``input``, and values in unselected slots (including inf/NaN) do not affect the result.
    """
    # torch.gather does not broadcast, so add the trailing gather axis and expand the size-1
    # dims (channels, typically) to match input. expand creates a view without copying.
    # Moving the index to input's device keeps working when the masks live on a different
    # device, which indexing tolerated.
    gather_index = index.to(device=input.device, dtype=torch.long).unsqueeze(-1)
    gather_index = gather_index.expand(*input.shape[:-1], 1)
    return torch.gather(input, -1, gather_index).squeeze(-1)


# Drop-in implementation of Conv2d that can apply masked scales&shifts to the convolution weights.
class ScaledWeightConv(_ConvNd):
    """Conv2d with a frozen base weight and ``breadth`` learnable (scale, shift) weight variants.

    ``forward`` runs one convolution per variant, using ``weight * scale + shift`` as the kernel.
    It stacks the results along a new trailing axis and then picks one variant per output
    location with ``index_2d``, driven by an integer mask.

    The base ``weight`` is tagged ``DO_NOT_TRAIN`` and has ``requires_grad`` disabled. Every scale
    and shift parameter is tagged ``FOR_SCALE_SHIFT``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride = 1,
        padding = 0,
        dilation = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        breadth: int = 8,
    ):
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        # Shape the scales/shifts exactly like the base weight:
        # (out_channels, in_channels // groups, kh, kw). Building the shape from in_channels and
        # an int kernel_size breaks broadcasting for groups > 1 and crashes for tuple kernel sizes.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList(
            [nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList(
            [nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for w, s in zip(self.weight_scales, self.shifts):
            w.FOR_SCALE_SHIFT = True
            s.FOR_SCALE_SHIFT = True
        # This should probably be configurable at some point.
        self.weight.DO_NOT_TRAIN = True
        self.weight.requires_grad = False

    def _weighted_conv_forward(self, input, weight):
        """Apply F.conv2d with ``weight`` and this module's bias, stride, padding, dilation and groups."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad; the conv itself then
            # uses no padding.
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor, masks: Optional[dict] = None) -> Tensor:
        """Run every scale/shift variant and keep one of them per output location.

        Args:
            input: Input forwarded to F.conv2d.
            masks: Mapping keyed by the size of the last spatial dimension of the conv output.
                Each value is an index tensor accepted by ``index_2d``. When None, ``self.masks``
                is used (it must have been assigned beforehand).

        Returns:
            The conv output in which each location comes from the variant chosen by the mask.
        """
        if masks is None:
            # An alternate "mode" of operation is the masks are injected as parameters.
            assert hasattr(self, 'masks'), 'No masks were passed and none are set on the module.'
            masks = self.masks

        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        weighted_convs = torch.stack(
            [self._weighted_conv_forward(input, self.weight * scale + shift)
             for scale, shift in zip(self.weight_scales, self.shifts)],
            dim=-1)

        # The stacked variants occupy the last axis, so shape[-2] is the last spatial dim of the
        # conv output; masks are keyed by that size.
        needed_mask = weighted_convs.shape[-2]
        assert needed_mask in masks, f'No mask provided for spatial size {needed_mask}.'

        return index_2d(weighted_convs, masks[needed_mask])


def create_wrapped_conv_from_template(conv: nn.Conv2d, breadth: int):
    """Build a ScaledWeightConv configured like ``conv`` that reuses ``conv``'s weight and bias.

    The template's ``weight`` Parameter is shared (not copied) and frozen: it is tagged
    ``DO_NOT_TRAIN`` and gets ``requires_grad = False``, which therefore also applies to ``conv``.
    ``conv.bias`` is shared unchanged. This mirrors create_wrapped_conv_transpose_from_template.

    Args:
        conv: Template convolution. Only square kernels are supported.
        breadth: Number of scale/shift variants to create.

    Returns:
        The new ScaledWeightConv.

    Raises:
        ValueError: if ``conv`` has a non-square kernel.
    """
    kernel_h, kernel_w = conv.kernel_size
    if kernel_h != kernel_w:
        raise ValueError(f'Only square kernels are supported, got kernel_size={conv.kernel_size}')
    # Stride/padding/dilation are passed as full tuples (the constructor's _pair keeps them
    # as-is), so asymmetric settings are preserved. The bias argument is a flag: passing the
    # Parameter itself makes `if bias:` raise for multi-element tensors.
    wrapped = ScaledWeightConv(conv.in_channels,
                               conv.out_channels,
                               kernel_h,
                               conv.stride,
                               conv.padding,
                               conv.dilation,
                               conv.groups,
                               conv.bias is not None,
                               conv.padding_mode,
                               breadth)
    wrapped.weight = conv.weight
    wrapped.weight.DO_NOT_TRAIN = True
    wrapped.weight.requires_grad = False
    wrapped.bias = conv.bias
    return wrapped


# Drop-in implementation of ConvTranspose2d that can apply masked scales&shifts to the convolution weights.
class ScaledWeightConvTranspose(_ConvTransposeNd):
    """ConvTranspose2d with a frozen base weight and ``breadth`` learnable (scale, shift) variants.

    ``forward`` runs one transposed convolution per variant, using ``weight * scale + shift``.
    It stacks the results along a new trailing axis and then picks one variant per output
    location with ``index_2d``, driven by an integer mask.

    The base ``weight`` is tagged ``DO_NOT_TRAIN`` and has ``requires_grad`` disabled. Every scale
    and shift parameter is tagged ``FOR_SCALE_SHIFT``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride = 1,
        padding = 0,
        output_padding = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding_mode: str = 'zeros',
        breadth: int = 8,
    ):
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode)

        # Shape the scales/shifts exactly like the base weight, which for transposed convs is
        # (in_channels, out_channels // groups, kh, kw). Building it from out_channels and an int
        # kernel_size breaks broadcasting for groups > 1 and crashes for tuple kernel sizes.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList(
            [nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList(
            [nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for w, s in zip(self.weight_scales, self.shifts):
            w.FOR_SCALE_SHIFT = True
            s.FOR_SCALE_SHIFT = True
        # This should probably be configurable at some point.
        self.weight.DO_NOT_TRAIN = True
        self.weight.requires_grad = False

    def _conv_transpose_forward(self, input, weight, output_size) -> Tensor:
        """Apply F.conv_transpose2d with ``weight`` and this module's bias and geometry.

        Raises:
            ValueError: if ``padding_mode`` is not 'zeros'.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        # NOTE(review): `_output_padding` is a private torch API. Newer torch versions appear to
        # take an extra `num_spatial_dims` argument before `dilation`. In that case, passing
        # `self.dilation` positionally here only works when output_size is None. Confirm against
        # the torch version in use.
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)

        return F.conv_transpose2d(
            input, weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    def forward(self, input: Tensor, masks: Optional[dict] = None, output_size: Optional[List[int]] = None) -> Tensor:
        """Run every scale/shift variant and keep one of them per output location.

        Args:
            input: Input forwarded to F.conv_transpose2d.
            masks: Mapping keyed by the size of the last spatial dimension of the output. Each
                value is an index tensor accepted by ``index_2d``. When None, ``self.masks`` is
                used (it must have been assigned beforehand).
            output_size: Optional requested output size, used to derive output_padding.

        Returns:
            The transposed-conv output in which each location comes from the variant chosen by
            the mask.
        """
        if masks is None:
            # An alternate "mode" of operation is the masks are injected as parameters.
            assert hasattr(self, 'masks'), 'No masks were passed and none are set on the module.'
            masks = self.masks

        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        weighted_convs = torch.stack(
            [self._conv_transpose_forward(input, self.weight * scale + shift, output_size)
             for scale, shift in zip(self.weight_scales, self.shifts)],
            dim=-1)

        # The stacked variants occupy the last axis, so shape[-2] is the last spatial dim of the
        # output; masks are keyed by that size.
        needed_mask = weighted_convs.shape[-2]
        assert needed_mask in masks, f'No mask provided for spatial size {needed_mask}.'

        return index_2d(weighted_convs, masks[needed_mask])


def create_wrapped_conv_transpose_from_template(conv: nn.ConvTranspose2d, breadth: int):
    """Build a ScaledWeightConvTranspose configured like ``conv`` that reuses its weight and bias.

    The template's ``weight`` Parameter is shared (not copied) and frozen: it is tagged
    ``DO_NOT_TRAIN`` and gets ``requires_grad = False``, which therefore also applies to ``conv``.
    ``conv.bias`` is shared unchanged.

    Args:
        conv: Template transposed convolution. Only square kernels are supported.
        breadth: Number of scale/shift variants to create.

    Returns:
        The new ScaledWeightConvTranspose.

    Raises:
        ValueError: if ``conv`` has a non-square kernel.
    """
    kernel_h, kernel_w = conv.kernel_size
    if kernel_h != kernel_w:
        raise ValueError(f'Only square kernels are supported, got kernel_size={conv.kernel_size}')
    # kernel_size is passed as a scalar because the constructor may feed it straight into
    # torch.ones(...), which rejects tuple dims. The bias argument is a flag: passing the
    # Parameter itself makes `if bias:` raise for multi-element tensors.
    wrapped = ScaledWeightConvTranspose(conv.in_channels,
                               conv.out_channels,
                               kernel_h,
                               conv.stride,
                               conv.padding,
                               conv.output_padding,
                               conv.groups,
                               conv.bias is not None,
                               conv.dilation,
                               conv.padding_mode,
                               breadth)
    wrapped.weight = conv.weight
    wrapped.weight.DO_NOT_TRAIN = True
    wrapped.weight.requires_grad = False
    wrapped.bias = conv.bias
    return wrapped