Add a way to disable grad on portions of the generator graph to save memory
This commit is contained in:
parent
e3adafbeac
commit
c74b9ee2e4
|
@ -41,6 +41,7 @@ class GrowingSRGBase(nn.Module):
|
|||
self.start_step = start_step
|
||||
self.latest_step = start_step
|
||||
self.fades = []
|
||||
self.counter = 0
|
||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||
|
||||
switches = []
|
||||
|
@ -105,7 +106,28 @@ class GrowingSRGBase(nn.Module):
|
|||
# The base param group starts at step 0, the rest are defined via progressive_switches.
|
||||
return [0] + self.progressive_schedule
|
||||
|
||||
# This method turns requires_grad on and off for different switches, allowing very large models to be trained while
|
||||
# using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
|
||||
# <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
|
||||
# Switches that are younger than 40000 steps are not eligible to be turned off.
|
||||
def do_switched_grad(self, groups=1):
|
||||
# If requires_grad is already disabled, don't bother.
|
||||
if not self.initial_conv.conv.weight.requires_grad or groups == 1:
|
||||
return
|
||||
self.counter = (self.counter + 1) % groups
|
||||
enabled = []
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
|
||||
for p in sw.parameters():
|
||||
p.requires_grad = False
|
||||
else:
|
||||
enabled.append(i)
|
||||
for p in sw.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
def forward(self, x):
|
||||
self.do_switched_grad(2)
|
||||
|
||||
x = self.initial_conv(x)
|
||||
|
||||
self.attentions = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user