forked from mrq/DL-Art-School
40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
|
from torch.autograd import Function, Variable
from torch.autograd.function import once_differentiable
from torch.nn.modules.module import Module

import channelnorm_cuda
|
||
|
|
||
|
class ChannelNormFunction(Function):
    """Autograd ``Function`` backed by the ``channelnorm_cuda`` extension.

    ``forward`` reduces a 4-D input whose size unpacks as ``(b, c, h, w)``
    to a ``(b, 1, h, w)`` tensor that is filled in by
    ``channelnorm_cuda.forward``; ``backward`` delegates to
    ``channelnorm_cuda.backward``. Both kernels receive ``norm_deg``;
    how they use it is defined by the extension (presumably the degree of
    the norm taken across channels -- TODO confirm against the CUDA source).
    """

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        """Compute the channel norm of ``input1``.

        Args:
            ctx: autograd context; stores tensors and ``norm_deg`` for
                ``backward``.
            input1: 4-D tensor whose size unpacks as ``(b, c, h, w)``.
                A non-contiguous input is copied to a contiguous buffer
                before being handed to the extension.
            norm_deg: forwarded unchanged to the CUDA kernels.

        Returns:
            A new ``(b, 1, h, w)`` tensor with the same dtype and device
            as ``input1``.
        """
        # The extension requires a dense layout (the original code asserted
        # contiguity). ``contiguous()`` is a no-op when that already holds,
        # and unlike ``assert`` it is not stripped under ``python -O``.
        input1 = input1.contiguous()
        b, _, h, w = input1.size()

        # Zero-initialised output buffer written in place by the kernel.
        output = input1.new_zeros((b, 1, h, w))
        channelnorm_cuda.forward(input1, output, norm_deg)

        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input1`` (``norm_deg`` gets none).

        The gradient is produced by an opaque CUDA call that records no
        autograd graph, so the function is marked ``once_differentiable``.
        Attempting double-backward therefore raises instead of silently
        producing wrong higher-order gradients.
        """
        input1, output = ctx.saved_tensors

        # Zero-initialised gradient buffer written in place by the kernel.
        # ``Variable``/``.data`` wrappers are unnecessary: since PyTorch 0.4
        # tensors and Variables are the same type, and
        # ``once_differentiable`` already runs this under ``no_grad``.
        grad_input1 = input1.new_zeros(input1.size())

        # grad_output is not guaranteed to be contiguous. The kernels
        # presumably assume a dense layout (forward required one), so
        # normalise it here; this is a no-op when it already is.
        channelnorm_cuda.backward(input1, output, grad_output.contiguous(),
                                  grad_input1, ctx.norm_deg)

        return grad_input1, None
|
||
|
|
||
|
|
||
|
class ChannelNorm(Module):
    """``Module`` front end for ``ChannelNormFunction``.

    The norm degree is fixed at construction time and passed along with
    the input to ``ChannelNormFunction.apply`` on every forward call.
    """

    def __init__(self, norm_deg=2):
        super().__init__()
        # Forwarded verbatim to ChannelNormFunction on each call.
        self.norm_deg = norm_deg

    def forward(self, input1):
        """Return ``ChannelNormFunction.apply(input1, self.norm_deg)``."""
        apply_norm = ChannelNormFunction.apply
        degree = self.norm_deg
        return apply_norm(input1, degree)
|
||
|
|