forked from mrq/DL-Art-School
Allow combined additive loss
This commit is contained in:
parent
98eada1e4c
commit
931ed903c1
|
@ -392,13 +392,18 @@ class SrDiffsInjector(Injector):
|
||||||
self.hq = opt['hq']
|
self.hq = opt['hq']
|
||||||
if self.mode == 'produce_diff':
|
if self.mode == 'produce_diff':
|
||||||
self.diff_key = opt['diff']
|
self.diff_key = opt['diff']
|
||||||
|
self.include_combined = opt['include_combined']
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
resampled_lq = state[self.lq]
|
resampled_lq = state[self.lq]
|
||||||
hq = state[self.hq]
|
hq = state[self.hq]
|
||||||
if self.mode == 'produce_diff':
|
if self.mode == 'produce_diff':
|
||||||
diff = hq - resampled_lq
|
diff = hq - resampled_lq
|
||||||
return {self.output: torch.cat([resampled_lq, diff], dim=1),
|
if self.include_combined:
|
||||||
|
res = torch.cat([resampled_lq, diff, hq], dim=1)
|
||||||
|
else:
|
||||||
|
res = torch.cat([resampled_lq, diff], dim=1)
|
||||||
|
return {self.output: res,
|
||||||
self.diff_key: diff}
|
self.diff_key: diff}
|
||||||
elif self.mode == 'recombine':
|
elif self.mode == 'recombine':
|
||||||
combined = resampled_lq + hq
|
combined = resampled_lq + hq
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user