Allow combined additive loss

This commit is contained in:
James Betker 2020-11-15 16:16:18 -07:00
parent 98eada1e4c
commit 931ed903c1
2 changed files with 7 additions and 2 deletions

View File

@ -392,13 +392,18 @@ class SrDiffsInjector(Injector):
self.hq = opt['hq']
if self.mode == 'produce_diff':
self.diff_key = opt['diff']
self.include_combined = opt['include_combined']
def forward(self, state):
resampled_lq = state[self.lq]
hq = state[self.hq]
if self.mode == 'produce_diff':
diff = hq - resampled_lq
return {self.output: torch.cat([resampled_lq, diff], dim=1),
if self.include_combined:
res = torch.cat([resampled_lq, diff, hq], dim=1)
else:
res = torch.cat([resampled_lq, diff], dim=1)
return {self.output: res,
self.diff_key: diff}
elif self.mode == 'recombine':
combined = resampled_lq + hq

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()