Fix up OOM issues when running a disjoint D update ratio and megabatches

This commit is contained in:
James Betker 2020-05-06 17:25:25 -06:00
parent eee9d6d9ca
commit 574e7e882b
2 changed files with 14 additions and 6 deletions

View File

@ -157,6 +157,14 @@ class SRGANModel(BaseModel):
if step > self.D_init_iters:
self.optimizer_G.zero_grad()
# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
for p in self.netG.parameters():
p.requires_grad = True
else:
for p in self.netG.parameters():
p.requires_grad = False
self.fake_GenOut = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
fake_GenOut = self.netG(var_L)
@ -164,23 +172,23 @@ class SRGANModel(BaseModel):
# Extract the image output. For generators that output skip-through connections, the master output is always
# the first element of the tuple.
if isinstance(fake_GenOut, tuple):
fake_H = fake_GenOut[0]
gen_img = fake_GenOut[0]
# TODO: Fix this.
self.fake_GenOut.append((fake_GenOut[0].detach(),
fake_GenOut[1].detach(),
fake_GenOut[2].detach()))
else:
fake_H = fake_GenOut
gen_img = fake_GenOut
self.fake_GenOut.append(fake_GenOut.detach())
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(pix).detach()
fake_fea = self.netF(fake_H)
fake_fea = self.netF(gen_img)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea

View File

@ -131,10 +131,10 @@ class FixupResNet(nn.Module):
x = self.final_defilter(x) + self.bias2
return x, skip_med, skip_lo
def fixup_resnet34(**kwargs):
def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
"""Constructs a Fixup-ResNet-34 model.
"""
model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs)
return model