diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 4cc9d03b..5e590ac7 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -32,13 +32,22 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         ctx.stride = stride
         ctx.breadth = s
-        ctx.save_for_backward(*[input, mask, weight, bias])
+        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
         return output

     @staticmethod
-    def backward(ctx, grad):
-        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
+    def backward(ctx, gradIn):
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        input, output, mask, weight, bias = ctx.saved_tensors
+        gradIn = gradIn
+
+        # Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension
+        # and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s
+        # and zeros that is multiplied by the output.)
+        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1, ctx.breadth, 1, 1)
+
+        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
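
The new grad_sel expression in this hunk can be sanity-checked against autograd with a small standalone sketch. It models the forward pass the way the in-diff comment describes it: the layer output gated element-wise by a selector that is broadcast over the channel dimension. The shapes and the breadth value below are made up purely for illustration and are not taken from the real layer.

import torch

# Hypothetical sizes for illustration only: batch 2, 8 channels, 4x4 spatial, switch breadth 3.
N, C, H, W, breadth = 2, 8, 4, 4, 3

y = torch.randn(N, C, H, W)                        # stands in for the saved layer output
sel = torch.ones(N, 1, H, W, requires_grad=True)   # per-position gate, broadcast across channels
out = sel * y                                      # the "matrix of 1s and zeros multiplied by the output"

grad_out = torch.randn_like(out)                   # stands in for gradIn arriving from upstream
out.backward(grad_out)

# Autograd's gradient w.r.t. the gate matches the hand-derived expression used in backward():
manual = (grad_out * y).sum(dim=1, keepdim=True)
print(torch.allclose(sel.grad, manual))            # True

# The return contract expects one selector-gradient plane per switch branch, hence the repeat:
grad_sel = manual.repeat(1, breadth, 1, 1)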