Add a way to disable grad on portions of the generator graph to save memory

This commit is contained in:
James Betker 2020-07-22 11:40:42 -06:00
parent e3adafbeac
commit c74b9ee2e4

View File

@ -41,6 +41,7 @@ class GrowingSRGBase(nn.Module):
self.start_step = start_step self.start_step = start_step
self.latest_step = start_step self.latest_step = start_step
self.fades = [] self.fades = []
self.counter = 0
assert self.upsample_factor == 2 or self.upsample_factor == 4 assert self.upsample_factor == 2 or self.upsample_factor == 4
switches = [] switches = []
@ -105,7 +106,28 @@ class GrowingSRGBase(nn.Module):
# The base param group starts at step 0, the rest are defined via progressive_switches. # The base param group starts at step 0, the rest are defined via progressive_switches.
return [0] + self.progressive_schedule return [0] + self.progressive_schedule
# This method turns requires_grad on and off for different switches, allowing very large models to be trained while
# using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
# <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
# Switches that are younger than 40000 steps are not eligible to be turned off.
def do_switched_grad(self, groups=1):
# If requires_grad is already disabled, don't bother.
if not self.initial_conv.conv.weight.requires_grad or groups == 1:
return
self.counter = (self.counter + 1) % groups
enabled = []
for i, sw in enumerate(self.progressive_switches):
if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
for p in sw.parameters():
p.requires_grad = False
else:
enabled.append(i)
for p in sw.parameters():
p.requires_grad = True
def forward(self, x): def forward(self, x):
self.do_switched_grad(2)
x = self.initial_conv(x) x = self.initial_conv(x)
self.attentions = [] self.attentions = []