diff --git a/codes/models/archs/ProgressiveSrg_arch.py b/codes/models/archs/ProgressiveSrg_arch.py
index ee0043f8..66d48586 100644
--- a/codes/models/archs/ProgressiveSrg_arch.py
+++ b/codes/models/archs/ProgressiveSrg_arch.py
@@ -41,6 +41,7 @@ class GrowingSRGBase(nn.Module):
         self.start_step = start_step
         self.latest_step = start_step
         self.fades = []
+        self.counter = 0
         assert self.upsample_factor == 2 or self.upsample_factor == 4
 
         switches = []
@@ -105,7 +106,28 @@ class GrowingSRGBase(nn.Module):
         # The base param group starts at step 0, the rest are defined via progressive_switches.
         return [0] + self.progressive_schedule
 
+    # This method turns requires_grad on and off for different switches, allowing very large models to be trained while
+    # using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
+    # <groups> controls the proportion of switches that are enabled: 1/groups will be enabled.
+    # Switches that are younger than 40000 steps are not eligible to be turned off.
+    def do_switched_grad(self, groups=1):
+        # If requires_grad is already disabled, don't bother.
+        if not self.initial_conv.conv.weight.requires_grad or groups == 1:
+            return
+        self.counter = (self.counter + 1) % groups
+        enabled = []
+        for i, sw in enumerate(self.progressive_switches):
+            if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
+                for p in sw.parameters():
+                    p.requires_grad = False
+            else:
+                enabled.append(i)
+                for p in sw.parameters():
+                    p.requires_grad = True
+
     def forward(self, x):
+        self.do_switched_grad(2)
+
         x = self.initial_conv(x)
 
         self.attentions = []
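
For readers unfamiliar with the trick this patch introduces, below is a minimal standalone sketch (not part of the patch) of the round-robin requires_grad gating that do_switched_grad() performs. ToyProgressiveNet, its block count, and the gate_grads() helper are hypothetical stand-ins for the repo's switch modules, and the 40000-step eligibility check is omitted for brevity.

# Standalone sketch of round-robin requires_grad gating (hypothetical toy model,
# not the repo's GrowingSRGBase). Each call enables gradients for 1/groups of the
# blocks and freezes the rest, rotating the active subset between calls.
import torch
import torch.nn as nn

class ToyProgressiveNet(nn.Module):
    def __init__(self, num_blocks=4, channels=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_blocks))
        self.counter = 0

    def gate_grads(self, groups=1):
        # Rotate which 1/groups of the blocks is trainable on this step.
        if groups == 1:
            return
        self.counter = (self.counter + 1) % groups
        for i, blk in enumerate(self.blocks):
            active = (i % groups == self.counter)
            for p in blk.parameters():
                p.requires_grad = active

    def forward(self, x):
        self.gate_grads(groups=2)
        for blk in self.blocks:
            x = blk(x)
        return x

net = ToyProgressiveNet()
out = net(torch.randn(1, 8, 16, 16))
out.mean().backward()
# Only the currently active half of the blocks accumulates gradients.
print([blk.weight.grad is not None for blk in net.blocks])

Because only the active subset of blocks receives gradients on a given step, autograd and the optimizer hold less state per step, which appears to be the memory saving the patch comment refers to; combined with gradient accumulation, every subset gets a turn to contribute before an optimizer update.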