Fix up OOM issues when running a disjoint D update ratio and megabatches
parent eee9d6d9ca
commit 574e7e882b
@@ -157,6 +157,14 @@ class SRGANModel(BaseModel):
         if step > self.D_init_iters:
             self.optimizer_G.zero_grad()
+
+        # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            for p in self.netG.parameters():
+                p.requires_grad = True
+        else:
+            for p in self.netG.parameters():
+                p.requires_grad = False
 
         self.fake_GenOut = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
             fake_GenOut = self.netG(var_L)
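The fix above relies on a PyTorch behavior: a forward pass through a module whose parameters all have requires_grad set to False builds no autograd graph. Under mega-batching, netG runs once per sub-batch, so on iterations where only the discriminator trains, leaving G's gradients enabled would retain one graph per sub-batch and exhaust GPU memory. A minimal self-contained sketch of the pattern; the generator, loss, and loop below are illustrative stand-ins, not the repository's code:

import torch
import torch.nn as nn

netG = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                     nn.Conv2d(8, 3, 3, padding=1))  # stand-in generator
D_update_ratio, D_init_iters = 2, 0  # illustrative schedule values

def run_megabatch(step, megabatch):
    # Freeze G on steps where it will not be optimized; forward passes through
    # a frozen network allocate no graph, so sub-batch activations are freed
    # immediately instead of piling up until backward().
    train_G = step % D_update_ratio == 0 and step > D_init_iters
    for p in netG.parameters():
        p.requires_grad = train_G

    fake_outs = []
    for var_L in megabatch:
        fake = netG(var_L)
        fake_outs.append(fake.detach())  # kept for the D pass, graph-free
        if train_G:
            (fake - var_L).abs().mean().backward()  # placeholder G loss
    return fake_outs

outs = run_megabatch(step=2, megabatch=[torch.randn(1, 3, 16, 16) for _ in range(4)])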
@@ -164,23 +172,23 @@ class SRGANModel(BaseModel):
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
             if isinstance(fake_GenOut, tuple):
-                fake_H = fake_GenOut[0]
+                gen_img = fake_GenOut[0]
                 # TODO: Fix this.
                 self.fake_GenOut.append((fake_GenOut[0].detach(),
                                          fake_GenOut[1].detach(),
                                          fake_GenOut[2].detach()))
             else:
-                fake_H = fake_GenOut
+                gen_img = fake_GenOut
                 self.fake_GenOut.append(fake_GenOut.detach())
 
             l_g_total = 0
             if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                 if self.cri_pix:  # pixel loss
-                    l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
+                    l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
                     l_g_total += l_g_pix
                 if self.cri_fea:  # feature loss
                     real_fea = self.netF(pix).detach()
-                    fake_fea = self.netF(fake_H)
+                    fake_fea = self.netF(gen_img)
                     l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                     l_g_total += l_g_fea
 
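The feature-loss branch in the hunk above follows the standard perceptual-loss pattern: features of the target image are detached so gradients flow only through the generated image. A self-contained sketch assuming a torchvision VGG-19 feature extractor as a stand-in for self.netF; the layer cut, criterion, and weight are illustrative:

import torch.nn as nn
from torchvision import models

netF = models.vgg19(pretrained=True).features[:35].eval()  # stand-in for self.netF
for p in netF.parameters():
    p.requires_grad = False  # the feature network itself is never trained

cri_fea = nn.L1Loss()

def feature_loss(gen_img, pix, l_fea_w=1.0):
    real_fea = netF(pix).detach()  # no graph through the target path
    fake_fea = netF(gen_img)       # gradients flow back into the generator
    return l_fea_w * cri_fea(fake_fea, real_fea)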
@@ -131,10 +131,10 @@ class FixupResNet(nn.Module):
         x = self.final_defilter(x) + self.bias2
         return x, skip_med, skip_lo
 
 
-def fixup_resnet34(**kwargs):
+def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
     """Constructs a Fixup-ResNet-34 model.
     """
-    model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
+    model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs)
     return model
 
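A hypothetical usage of the reparameterized constructor; the import path, input shape, and block counts below are assumptions, and the three return values mirror the return statement in the hunk above:

import torch
from FixupResnet import fixup_resnet34  # module path is an assumption

model = fixup_resnet34(nb_denoiser=16, nb_upsampler=12)  # example counts, not the defaults
lr_batch = torch.randn(1, 3, 64, 64)                     # assumed 3-channel input
out, skip_med, skip_lo = model(lr_batch)                 # master output plus two skip outputs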