diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index fbfb3e3f..bb9bbbac 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -543,9 +543,9 @@ class SwitchedSpsrWithRef2(nn.Module): attention_norm=True, transform_count=self.transformation_counts, init_temp=init_temperature, add_scalable_noise_to_transforms=False) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) + self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)]) + self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True) self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw] self.attentions = None diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 7d845029..9e8080be 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -95,10 +95,15 @@ class ConfigurableStep(Module): local_state.update(injected) new_state.update(injected) - if train: + if train and len(self.losses) > 0: # Finally, compute the losses. total_loss = 0 for loss_name, loss in self.losses.items(): + # Some losses only activate after a set number of steps. For example, proto-discriminator losses can + # be very disruptive to a generator. + if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']: + continue + l = loss(self.training_net, local_state) total_loss += l * self.weights[loss_name] # Record metrics. diff --git a/codes/train.py b/codes/train.py index f993f4e9..661e9b92 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)