Allow combined additive loss
This commit is contained in:
parent
98eada1e4c
commit
931ed903c1
|
@ -392,13 +392,18 @@ class SrDiffsInjector(Injector):
|
|||
self.hq = opt['hq']
|
||||
if self.mode == 'produce_diff':
|
||||
self.diff_key = opt['diff']
|
||||
self.include_combined = opt['include_combined']
|
||||
|
||||
def forward(self, state):
|
||||
resampled_lq = state[self.lq]
|
||||
hq = state[self.hq]
|
||||
if self.mode == 'produce_diff':
|
||||
diff = hq - resampled_lq
|
||||
return {self.output: torch.cat([resampled_lq, diff], dim=1),
|
||||
if self.include_combined:
|
||||
res = torch.cat([resampled_lq, diff, hq], dim=1)
|
||||
else:
|
||||
res = torch.cat([resampled_lq, diff], dim=1)
|
||||
return {self.output: res,
|
||||
self.diff_key: diff}
|
||||
elif self.mode == 'recombine':
|
||||
combined = resampled_lq + hq
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user