Fix SwitchedConvHardRoutingFunction for current cuda router
This commit is contained in:
parent d7bec392dd
commit 1405ff06b8
@@ -32,13 +32,22 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         ctx.stride = stride
+        ctx.breadth = s
-        ctx.save_for_backward(*[input, mask, weight, bias])
+        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
         return output
 
     @staticmethod
-    def backward(ctx, grad):
-        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
+    def backward(ctx, gradIn):
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        input, output, mask, weight, bias = ctx.saved_tensors
+        gradIn = gradIn
+
+        # Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension
+        # and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s
+        # and zeros that is multiplied by the output.)
+        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1)
+
+        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
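The comment in the new backward() claims the selector gradient is the channel-summed product of the incoming gradient with the layer output, broadcast across the switch breadth. Below is a minimal pure-PyTorch sketch that checks that claim against autograd, assuming the forward pass computes out = sum_b mask[:, b] * conv_b(x) with a hard (one-hot) mask; the branch convolutions, shapes, and names here are illustrative only, not the switched_conv_cuda_naive kernel.

# Hedged sketch: verify the selector-gradient identity used by grad_sel in the commit
# above, using plain PyTorch autograd on an emulated hard-routed switched conv.
import torch
import torch.nn as nn

B, Cin, Cout, H, W, breadth = 2, 3, 4, 8, 8, 5
x = torch.randn(B, Cin, H, W)
branches = nn.ModuleList(nn.Conv2d(Cin, Cout, 3, padding=1) for _ in range(breadth))

# Hard routing: a one-hot selector over the breadth dimension (branch 0 everywhere).
mask = torch.zeros(B, breadth, H, W)
mask[:, 0] = 1.0
mask.requires_grad_(True)

branch_outs = torch.stack([conv(x) for conv in branches], dim=1)  # (B, breadth, Cout, H, W)
out = (mask.unsqueeze(2) * branch_outs).sum(dim=1)                # (B, Cout, H, W)

grad_out = torch.randn_like(out)
out.backward(grad_out)

# Autograd's gradient w.r.t. the selector is the channel-summed product of the
# incoming gradient with each branch's output ...
ref = (grad_out.unsqueeze(1) * branch_outs).sum(dim=2)            # (B, breadth, H, W)
print(torch.allclose(mask.grad, ref, atol=1e-5))                  # True

# ... and at the selected branch it collapses to (grad_out * out).sum(dim=1), the
# quantity the commit computes as grad_sel before broadcasting it across the breadth.
print(torch.allclose(mask.grad[:, 0], (grad_out * out).sum(dim=1), atol=1e-5))  # True

The repeat(1, ctx.breadth, 1, 1) in the commit presumably broadcasts that selected-branch value across the breadth because only the routed output, not every branch's output, is saved for backward.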