Fix up OOM issues when running a disjoint D update ratio and megabatches
This commit is contained in:
parent
eee9d6d9ca
commit
574e7e882b
|
@ -157,6 +157,14 @@ class SRGANModel(BaseModel):
|
|||
if step > self.D_init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
for p in self.netG.parameters():
|
||||
p.requires_grad = True
|
||||
else:
|
||||
for p in self.netG.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.fake_GenOut = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
fake_GenOut = self.netG(var_L)
|
||||
|
@ -164,23 +172,23 @@ class SRGANModel(BaseModel):
|
|||
# Extract the image output. For generators that output skip-through connections, the master output is always
|
||||
# the first element of the tuple.
|
||||
if isinstance(fake_GenOut, tuple):
|
||||
fake_H = fake_GenOut[0]
|
||||
gen_img = fake_GenOut[0]
|
||||
# TODO: Fix this.
|
||||
self.fake_GenOut.append((fake_GenOut[0].detach(),
|
||||
fake_GenOut[1].detach(),
|
||||
fake_GenOut[2].detach()))
|
||||
else:
|
||||
fake_H = fake_GenOut
|
||||
gen_img = fake_GenOut
|
||||
self.fake_GenOut.append(fake_GenOut.detach())
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(pix).detach()
|
||||
fake_fea = self.netF(fake_H)
|
||||
fake_fea = self.netF(gen_img)
|
||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
|
|
|
@ -131,10 +131,10 @@ class FixupResNet(nn.Module):
|
|||
x = self.final_defilter(x) + self.bias2
|
||||
return x, skip_med, skip_lo
|
||||
|
||||
def fixup_resnet34(**kwargs):
|
||||
def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
|
||||
"""Constructs a Fixup-ResNet-34 model.
|
||||
"""
|
||||
model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
|
||||
model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user