forked from mrq/DL-Art-School
Add a way to disable grad on portions of the generator graph to save memory
This commit is contained in:
parent e3adafbeac
commit c74b9ee2e4
@@ -41,6 +41,7 @@ class GrowingSRGBase(nn.Module):
         self.start_step = start_step
         self.latest_step = start_step
         self.fades = []
+        self.counter = 0
         assert self.upsample_factor == 2 or self.upsample_factor == 4

         switches = []
@@ -105,7 +106,28 @@ class GrowingSRGBase(nn.Module):
         # The base param group starts at step 0, the rest are defined via progressive_switches.
         return [0] + self.progressive_schedule

+    # This method turns requires_grad on and off for different switches, allowing very large models to be trained while
+    # using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
+    # <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
+    # Switches that are younger than 40000 steps are not eligible to be turned off.
+    def do_switched_grad(self, groups=1):
+        # If requires_grad is already disabled, don't bother.
+        if not self.initial_conv.conv.weight.requires_grad or groups == 1:
+            return
+        self.counter = (self.counter + 1) % groups
+        enabled = []
+        for i, sw in enumerate(self.progressive_switches):
+            if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
+                for p in sw.parameters():
+                    p.requires_grad = False
+            else:
+                enabled.append(i)
+                for p in sw.parameters():
+                    p.requires_grad = True
+
     def forward(self, x):
+        self.do_switched_grad(2)
+
         x = self.initial_conv(x)

         self.attentions = []
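For readers skimming the diff: the sketch below restates the round-robin requires_grad toggle that do_switched_grad() performs, applied to a toy model so it can be run standalone. Everything except the toggling pattern itself (ToyProgressiveNet, its linear blocks, the training loop) is hypothetical, and the 40000-step age check against progressive_schedule is omitted for brevity.

# Standalone sketch (not part of the commit) of the round-robin requires_grad
# toggle shown above. ToyProgressiveNet and the training loop are hypothetical.
import torch
import torch.nn as nn

class ToyProgressiveNet(nn.Module):
    def __init__(self, num_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_blocks))
        self.head = nn.Linear(16, 1)
        self.counter = 0

    def do_switched_grad(self, groups=1):
        # With groups=1 every block stays trainable; otherwise only the blocks
        # whose index matches the rotating counter keep requires_grad=True.
        if groups == 1:
            return
        self.counter = (self.counter + 1) % groups
        for i, blk in enumerate(self.blocks):
            trainable = (i % groups == self.counter)
            for p in blk.parameters():
                p.requires_grad = trainable

    def forward(self, x):
        self.do_switched_grad(groups=2)
        for blk in self.blocks:
            x = torch.relu(blk(x))
        return self.head(x)

# Gradient accumulation over `groups` micro-batches: each forward pass enables
# a different subset of blocks, so after a full cycle every block has received
# gradient, while each individual backward pass only builds gradient buffers
# for roughly 1/groups of the parameters.
model = ToyProgressiveNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt.zero_grad()
for step in range(2):  # 2 micro-batches == groups
    x = torch.randn(8, 16)
    loss = model(x).mean()
    loss.backward()
opt.step()

Because the counter rotates before each forward pass, an opt.step() taken after the accumulation cycle still updates the whole network; gradients continue to flow through the frozen blocks to upstream trainable ones, only their own weight gradients are skipped.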