From 1405ff06b8e5a43537312bb3bf818b6623dd60b6 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 3 Feb 2021 14:11:55 -0700
Subject: [PATCH] Fix SwitchedConvHardRoutingFunction for current cuda router

---
 .../switched_conv/switched_conv_hard_routing.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 4cc9d03b..5e590ac7 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -32,13 +32,19 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
 
         ctx.stride = stride
         ctx.breadth = s
-        ctx.save_for_backward(*[input, mask, weight, bias])
+        ctx.save_for_backward(*[input, output.detach().clone(), mask, weight, bias])
         return output
 
     @staticmethod
-    def backward(ctx, grad):
-        input, mask, weight, bias = ctx.saved_tensors
-        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
+    def backward(ctx, gradIn):
+        input, output, mask, weight, bias = ctx.saved_tensors
+
+        # The selector gradient is the element-wise product of the incoming gradient with the layer output, summed
+        # across the channel dimension and repeated along the breadth of the switch. (Think of the forward pass as
+        # multiplying the output by a selector matrix of ones and zeros.)
+        grad_sel = (gradIn * output).sum(dim=1, keepdim=True).repeat(1, ctx.breadth, 1, 1)
+
+        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, gradIn.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
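
As a quick sanity check of the selector-gradient rule described in the comment above, the following standalone sketch reproduces the same reduction in plain PyTorch. It is not part of the patch, and the tensor names and sizes are illustrative assumptions rather than values taken from the repository:

    import torch

    b, c, breadth, h, w = 2, 4, 8, 16, 16
    grad_in = torch.randn(b, c, h, w)  # incoming gradient w.r.t. the layer output
    output = torch.randn(b, c, h, w)   # stands in for the output saved in ctx during forward

    # Element-wise product with the output, summed over the channel dimension,
    # then repeated along the switch breadth: the same expression used in backward().
    grad_sel = (grad_in * output).sum(dim=1, keepdim=True).repeat(1, breadth, 1, 1)
    print(grad_sel.shape)  # torch.Size([2, 8, 16, 16]), one value per switch slot per pixel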