From 574e7e882b1a14eacad250f473a9d0903793ebf4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 6 May 2020 17:25:25 -0600
Subject: [PATCH] Fix up OOM issues when running a disjoint D update ratio and
 megabatches

---
 codes/models/SRGAN_model.py       | 16 ++++++++++++----
 codes/models/archs/ResGen_arch.py |  4 ++--
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index a691015e..117cedeb 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -157,6 +157,14 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
 
+        # Turning off G's gradients on steps where G will not be updated is required for mega-batching and D_update_ratio to work together; without it the run OOMs.
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            for p in self.netG.parameters():
+                p.requires_grad = True
+        else:
+            for p in self.netG.parameters():
+                p.requires_grad = False
+
         self.fake_GenOut = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
             fake_GenOut = self.netG(var_L)
@@ -164,23 +172,23 @@ class SRGANModel(BaseModel):
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
             if isinstance(fake_GenOut, tuple):
-                fake_H = fake_GenOut[0]
+                gen_img = fake_GenOut[0]
                 # TODO: Fix this.
                 self.fake_GenOut.append((fake_GenOut[0].detach(),
                                          fake_GenOut[1].detach(),
                                          fake_GenOut[2].detach()))
             else:
-                fake_H = fake_GenOut
+                gen_img = fake_GenOut
                 self.fake_GenOut.append(fake_GenOut.detach())
 
             l_g_total = 0
             if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                 if self.cri_pix:  # pixel loss
-                    l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
+                    l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
                     l_g_total += l_g_pix
                 if self.cri_fea:  # feature loss
                     real_fea = self.netF(pix).detach()
-                    fake_fea = self.netF(fake_H)
+                    fake_fea = self.netF(gen_img)
                     l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                     l_g_total += l_g_fea
 
diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py
index 1d908a36..8783e252 100644
--- a/codes/models/archs/ResGen_arch.py
+++ b/codes/models/archs/ResGen_arch.py
@@ -131,10 +131,10 @@ class FixupResNet(nn.Module):
         x = self.final_defilter(x) + self.bias2
         return x, skip_med, skip_lo
 
 
-def fixup_resnet34(**kwargs):
+def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
     """Constructs a Fixup-ResNet-34 model.
     """
-    model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
+    model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs)
     return model
 
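
Note on the requires_grad toggling added above: the sketch below shows the same pattern in isolation, presumably why it sidesteps the OOM (with every G parameter's requires_grad set to False and a non-grad input, netG's forward pass builds no autograd graph on steps that only update D). It is a minimal stand-alone approximation, not the training code: the tiny nn.Sequential generator and the D_update_ratio / D_init_iters values are assumptions, and only the names mirror SRGAN_model.py.

import torch
import torch.nn as nn

# Hypothetical stand-in for netG; the real generator in SRGAN_model.py is a full SR network.
netG = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                     nn.ReLU(),
                     nn.Conv2d(8, 3, 3, padding=1))

D_update_ratio = 2  # assumed value: G only steps every other iteration
D_init_iters = 0    # assumed value: no D-only warmup period

def set_generator_grads(step):
    # Same condition as the patch: autograd only tracks G on steps where G is updated.
    updating_G = step % D_update_ratio == 0 and step > D_init_iters
    for p in netG.parameters():
        p.requires_grad = updating_G
    return updating_G

# Megabatch-style loop: several low-res chunks are pushed through G per training step.
megabatch = [torch.randn(1, 3, 16, 16) for _ in range(4)]
for step in range(1, 5):
    updating_G = set_generator_grads(step)
    for var_L in megabatch:
        fake = netG(var_L)
        # On D-only steps the output carries no autograd graph, so looping over
        # chunks does not pile up graph memory.
        assert fake.requires_grad == updating_G

Flipping the flags back to True on G-update steps matters just as much: if they stayed off, the generator losses in the patch would produce no gradients for optimizer_G to apply.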