From 202eb11fdcf7335faf5e78093c2040737088a21d Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 9 Oct 2020 19:51:44 -0600 Subject: [PATCH] For element loss added --- codes/models/steps/losses.py | 19 +++++++++++++++++++ codes/train2.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 2178aa3e..9543ed8f 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -30,6 +30,8 @@ def create_loss(opt_loss, env): return RecursiveInvarianceLoss(opt_loss, env) elif type == 'recurrent': return RecurrentLoss(opt_loss, env) + elif type == 'for_element': + return ForElementLoss(opt_loss, env) else: raise NotImplementedError @@ -396,3 +398,20 @@ class RecurrentLoss(ConfigurableLoss): total_loss += self.loss(net, st) return total_loss + +# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss. +class ForElementLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(ForElementLoss, self).__init__(opt, env) + o = opt.copy() + o['type'] = opt['subtype'] + self.index = opt['index'] + o['fake'] = '_fake' + o['real'] = '_real' + self.loss = create_loss(o, self.env) + + def forward(self, net, state): + st = state.copy() + st['_real'] = state[self.opt['real']][:, self.index] + st['_fake'] = state[self.opt['fake']][:, self.index] + return self.loss(net, st) diff --git a/codes/train2.py b/codes/train2.py index f6dd496a..404bdfd4 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()