From 1405ff06b8e5a43537312bb3bf818b6623dd60b6 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 3 Feb 2021 14:11:55 -0700
Subject: [PATCH] Fix SwitchedConvHardRoutingFunction for current cuda router

---
 .../switched_conv/switched_conv_hard_routing.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 4cc9d03b..5e590ac7 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -32,13 +32,22 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         ctx.stride = stride
         ctx.breadth = s
-        ctx.save_for_backward(*[input, mask, weight, bias])
+        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
         return output
 
     @staticmethod
-    def backward(ctx, grad):
-        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
+    def backward(ctx, gradIn):
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        input, output, mask, weight, bias = ctx.saved_tensors
+        gradIn = gradIn
+
+        # Selector grad is simply the element-wise product of grad with the output of the layer, summed across the channel dimension
+        # and repeated along the breadth of the switch. (Think of the forward operation using the selector as a simple matrix of 1s
+        # and zeros that is multiplied by the output.)
+        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1,ctx.breadth,1,1)
+
+        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
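
A minimal pure-PyTorch sketch of the selector-gradient rule described in the backward() comment above, assuming a one-hot selector of breadth S and NCHW tensors. The reference forward pass (switched_conv_reference), its tensor shapes, and the padding choice are illustrative assumptions only; they are not part of this patch or of the switched_conv_cuda_naive extension.

    import torch
    import torch.nn.functional as F

    def switched_conv_reference(input, selector, weight, bias, stride=1):
        # Assumed shapes (not taken from the patch):
        #   input:    (N, C_in, H, W)
        #   selector: (N, S, H_out, W_out), one-hot across the switch breadth S,
        #             spatially matching the conv output
        #   weight:   (C_out, S, C_in, K, K), one conv kernel per switch position
        #   bias:     (C_out,)
        s = selector.shape[1]
        outs = []
        for i in range(s):
            # Run each per-switch convolution, then mask it with that switch's selector plane.
            o = F.conv2d(input, weight[:, i], bias, stride=stride,
                         padding=weight.shape[-1] // 2)
            outs.append(o * selector[:, i:i + 1])
        # With a one-hot selector, summing the masked branches picks exactly one branch per pixel.
        return sum(outs)

    def selector_grad(grad_out, output, breadth):
        # The rule added in backward(): element-wise product of the incoming gradient with the
        # layer output, summed over the channel dimension and repeated across the switch breadth.
        return (grad_out * output).sum(dim=1, keepdim=True).repeat(1, breadth, 1, 1)

Because the selector acts as a matrix of ones and zeros in the forward pass, the channel-summed product of the incoming gradient with the output recovers the selected branch's contribution at each pixel, which is why the same value is broadcast to every switch position.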