Fix SwitchedConvHardRoutingFunction for current cuda router

James Betker 2021-02-03 14:11:55 -07:00
parent d7bec392dd
commit 1405ff06b8


@@ -32,13 +32,22 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         ctx.stride = stride
         ctx.breadth = s
-        ctx.save_for_backward(*[input, mask, weight, bias])
+        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
         return output
     @staticmethod
-    def backward(ctx, grad):
-        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
+    def backward(ctx, gradIn):
+        input, output, mask, weight, bias = ctx.saved_tensors
+        # The selector grad is simply the element-wise product of the incoming gradient
+        # with the output of the layer, summed across the channel dimension and repeated
+        # along the breadth of the switch. (Think of the forward operation as multiplying
+        # the output by a selector matrix of ones and zeros.)
+        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1, ctx.breadth, 1, 1)
+        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
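
The analytic selector gradient introduced here can be sanity-checked against autograd. Below is a minimal pure-PyTorch sketch, not the CUDA kernel: the weight layout, the dense one-hot formulation of the routing mask, and the toy shapes are all assumptions for illustration. With a hard (one-hot) selector the identity holds exactly at the selected slots, since there the layer output equals the chosen branch's output; the commit then broadcasts that value across the full breadth of the switch.

    # Sketch only: a dense reference forward for a hard-routed switched conv,
    # used to check grad_sel = (gradIn * output).sum(dim=1) repeated over breadth.
    import torch
    import torch.nn.functional as F

    N, Cin, Cout, H, W, breadth, k = 2, 3, 4, 8, 8, 5, 3  # assumed toy shapes

    input = torch.randn(N, Cin, H, W)
    # One convolution kernel per switch slot (assumed weight layout).
    weight = torch.randn(breadth, Cout, Cin, k, k)
    # Hard routing: a one-hot choice over the breadth dimension at every pixel.
    mask = F.one_hot(torch.randint(0, breadth, (N, H, W)), breadth)
    mask = mask.permute(0, 3, 1, 2).float().requires_grad_(True)

    # Reference forward: run every branch, let the one-hot mask pick per pixel.
    branch_outs = torch.stack(
        [F.conv2d(input, weight[b], padding=k // 2) for b in range(breadth)], dim=1
    )  # (N, breadth, Cout, H, W)
    output = (branch_outs * mask.unsqueeze(2)).sum(dim=1)  # (N, Cout, H, W)

    grad_out = torch.randn_like(output)
    output.backward(grad_out)

    # The commit's analytic selector gradient: channel-summed grad * output,
    # repeated along the breadth of the switch.
    grad_sel = (grad_out * output).sum(dim=1, keepdim=True).repeat(1, breadth, 1, 1)

    # With a one-hot mask, autograd agrees exactly at the *selected* slots,
    # because there output == the chosen branch's output.
    selected = mask > 0
    assert torch.allclose(grad_sel[selected], mask.grad[selected], atol=1e-4)
    print("selector gradient matches autograd at selected slots")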