# DL-Art-School/codes/models/layers/channelnorm_package/channelnorm.py

# 40 lines
# 1.1 KiB
# Python
# Raw Normal View History

# 2020-09-19 16:07:00 +00:00
from torch.autograd import Function, Variable
from torch.nn.modules.module import Module
import channelnorm_cuda
class ChannelNormFunction(Function):
    """Autograd function that delegates to the ``channelnorm_cuda`` extension.

    ``forward`` reduces the channel dimension of a 4-D input to a single
    channel and ``backward`` asks the extension for the matching input
    gradient. All numerical work happens inside ``channelnorm_cuda``.

    NOTE(review): presumably this is the per-position norm of degree
    ``norm_deg`` taken across channels -- confirm against the kernel source.
    """

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        """Run the extension's forward kernel.

        Args:
            ctx: Autograd context; receives the tensors and ``norm_deg``
                needed by ``backward``.
            input1: 4-D tensor unpacked as ``(b, _, h, w)``. Must be
                contiguous.
            norm_deg: Degree handed unchanged to the kernel (default 2).

        Returns:
            A ``(b, 1, h, w)`` tensor with ``input1``'s dtype and device,
            zero-initialised here and then filled in by the kernel.

        Raises:
            AssertionError: If ``input1`` is not contiguous.
        """
        # Kept as an assert so callers see the same exception type as before.
        assert input1.is_contiguous(), "ChannelNorm expects a contiguous input tensor"
        b, _, h, w = input1.size()
        # new_zeros keeps input1's dtype/device; replaces legacy new(...).zero_().
        output = input1.new_zeros((b, 1, h, w))
        channelnorm_cuda.forward(input1, output, norm_deg)
        # backward needs both the input and the computed output.
        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the extension's backward kernel.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the forward output.

        Returns:
            A pair ``(grad_input1, None)``: the gradient for ``input1``
            (same size as ``input1``) and ``None`` for ``norm_deg``, which
            receives no gradient.
        """
        input1, output = ctx.saved_tensors
        # Plain tensor buffer; the deprecated Variable wrapper is not needed.
        grad_input1 = input1.new_zeros(input1.size())
        # detach() replaces the discouraged ``.data`` access: the kernel gets
        # a tensor that is cut off from any autograd graph.
        channelnorm_cuda.backward(
            input1, output, grad_output.detach(), grad_input1, ctx.norm_deg
        )
        return grad_input1, None
class ChannelNorm(Module):
    """Module wrapper that applies ``ChannelNormFunction`` with a fixed degree.

    Args:
        norm_deg: Degree passed through to ``ChannelNormFunction`` on every
            forward call (default 2).
    """

    def __init__(self, norm_deg=2):
        super().__init__()
        self.norm_deg = norm_deg

    def forward(self, input1):
        """Return ``ChannelNormFunction.apply(input1, self.norm_deg)``."""
        return ChannelNormFunction.apply(input1, self.norm_deg)