Fix SPSR calls into SwitchComputer
This commit is contained in:
parent
bdf4c38899
commit
cc915303a5
|
@ -271,18 +271,18 @@ class Spsr5(nn.Module):
|
|||
|
||||
x = self.model_fea_conv(x)
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, embedding))
|
||||
x1, a1 = self.sw1(x1, identity=x, att_in=(x1, embedding))
|
||||
|
||||
x2 = x1
|
||||
x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2))
|
||||
x2, a2 = self.sw2(x2, True, identity=x1, att_in=(x2, embedding))
|
||||
x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, embedding))
|
||||
noise_stds.append(nstd)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
||||
x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, embedding))
|
||||
x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, embedding))
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
x_grad = self.grad_lr_conv2(x_grad)
|
||||
x_grad_out = self.upsample_grad(x_grad)
|
||||
|
@ -292,7 +292,7 @@ class Spsr5(nn.Module):
|
|||
x_out = x2
|
||||
x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, embedding))
|
||||
x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, embedding))
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = self.upsample(x_out)
|
||||
x_out = self.final_hr_conv1(x_out)
|
||||
|
@ -404,15 +404,15 @@ class Spsr6(nn.Module):
|
|||
|
||||
x = self.model_fea_conv(x)
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, True, identity=x)
|
||||
x1, a1 = self.sw1(x1, identity=x)
|
||||
|
||||
x2 = x1
|
||||
x2, a2 = self.sw2(x2, True, identity=x1)
|
||||
x2, a2 = self.sw2(x2, identity=x1)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
||||
x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity)
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
x_grad = self.grad_lr_conv2(x_grad)
|
||||
x_grad_out = self.upsample_grad(x_grad)
|
||||
|
@ -420,7 +420,7 @@ class Spsr6(nn.Module):
|
|||
|
||||
x_out = x2
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||
x_out, a4 = self.conjoin_sw(x_out, identity=x2)
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||
|
@ -543,15 +543,15 @@ class Spsr7(nn.Module):
|
|||
x = x + br
|
||||
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding))
|
||||
x1, a1 = self.sw1(x1, identity=x, att_in=(x1, ref_embedding), do_checkpointing=True)
|
||||
|
||||
x2 = x1
|
||||
x2, a2 = self.sw2(x2, True, identity=x1, att_in=(x2, ref_embedding))
|
||||
x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, ref_embedding), do_checkpointing=True)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
x_grad, grad_fea_std = checkpoint(self.grad_ref_join, x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
|
||||
x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, ref_embedding), do_checkpointing=True)
|
||||
x_grad = checkpoint(self.grad_lr_conv, x_grad)
|
||||
x_grad = checkpoint(self.grad_lr_conv2, x_grad)
|
||||
x_grad_out = checkpoint(self.upsample_grad, x_grad)
|
||||
|
@ -559,7 +559,7 @@ class Spsr7(nn.Module):
|
|||
|
||||
x_out = x2
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
|
||||
x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, ref_embedding), do_checkpointing=True)
|
||||
x_out = checkpoint(self.final_lr_conv, x_out)
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||
|
@ -620,9 +620,9 @@ class AttentionBlock(nn.Module):
|
|||
def forward(self, x, mplex_ref=None, ref=None):
|
||||
if self.ref_join is not None:
|
||||
branch, ref_std = self.ref_join(x, ref)
|
||||
return self.switch(branch, True, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
|
||||
return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
|
||||
else:
|
||||
return self.switch(x, True, identity=x, att_in=(x, mplex_ref))
|
||||
return self.switch(x, identity=x, att_in=(x, mplex_ref))
|
||||
|
||||
|
||||
# SPSR7 with incremental improvements and also using the new AttentionBlock to save gpu memory.
|
||||
|
|
Loading…
Reference in New Issue
Block a user