From e4b89a172f1e76cd385f68c07360b1097e5e77ab Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 5 Oct 2020 22:05:56 -0600 Subject: [PATCH] Reduce spsr7 memory usage --- codes/models/archs/SPSR_arch.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index f72e26d8..b719ac0f 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -539,20 +539,20 @@ class Spsr7(nn.Module): x_grad = self.grad_conv(x_grad) x_grad_identity = x_grad - x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1) + x_grad, grad_fea_std = checkpoint(self.grad_ref_join, x_grad, x1) x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding)) - x_grad = self.grad_lr_conv(x_grad) - x_grad = self.grad_lr_conv2(x_grad) - x_grad_out = self.upsample_grad(x_grad) - x_grad_out = self.grad_branch_output_conv(x_grad_out) + x_grad = checkpoint(self.grad_lr_conv, x_grad) + x_grad = checkpoint(self.grad_lr_conv2, x_grad) + x_grad_out = checkpoint(self.upsample_grad, x_grad) + x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) x_out = x2 x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding)) - x_out = self.final_lr_conv(x_out) + x_out = checkpoint(self.final_lr_conv, x_out) x_out = checkpoint(self.upsample, x_out) x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = self.final_hr_conv2(x_out) + x_out = checkpoint(self.final_hr_conv2, x_out) self.attentions = [a1, a2, a3, a4] self.grad_fea_std = grad_fea_std.detach().cpu()