From 931ed903c1c9bea8f9b8ff30a38b492f9abb5af6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 15 Nov 2020 16:16:18 -0700 Subject: [PATCH] Allow combined additive loss --- codes/models/steps/injectors.py | 7 ++++++- codes/train2.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 5d584980..333601c5 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -392,13 +392,18 @@ class SrDiffsInjector(Injector): self.hq = opt['hq'] if self.mode == 'produce_diff': self.diff_key = opt['diff'] + self.include_combined = opt['include_combined'] def forward(self, state): resampled_lq = state[self.lq] hq = state[self.hq] if self.mode == 'produce_diff': diff = hq - resampled_lq - return {self.output: torch.cat([resampled_lq, diff], dim=1), + if self.include_combined: + res = torch.cat([resampled_lq, diff, hq], dim=1) + else: + res = torch.cat([resampled_lq, diff], dim=1) + return {self.output: res, self.diff_key: diff} elif self.mode == 'recombine': combined = resampled_lq + hq diff --git a/codes/train2.py b/codes/train2.py index bcc65a45..5f964370 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()