From a9c2e973910ad7acda866ec6cef10756d78237eb Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 11 Oct 2020 08:20:07 -0600 Subject: [PATCH] Constant injector and teco fixes --- codes/models/steps/injectors.py | 22 ++++++++++++++++++++-- codes/models/steps/tecogan_losses.py | 12 +++++++----- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 9ef6fb1e..60cb94af 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -36,6 +36,8 @@ def create_injector(opt_inject, env): return MarginRemoval(opt_inject, env) elif type == 'foreach': return ForEachInjector(opt_inject, env) + elif type == 'constant': + return ConstantInjector(opt_inject, env) else: raise NotImplementedError @@ -220,7 +222,6 @@ class MarginRemoval(Injector): input = state[self.input] return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} - # Produces an injection which is composed of applying a single injector multiple times across a single dimension. class ForEachInjector(Injector): def __init__(self, opt, env): @@ -238,4 +239,21 @@ class ForEachInjector(Injector): for i in range(inputs.shape[1]): st['_in'] = inputs[:, i] injs.append(self.injector(st)['_out']) - return {self.output: torch.stack(injs, dim=1)} + return {self.output: torch.stack(injs, dim=1)} + + +class ConstantInjector(Injector): + def __init__(self, opt, env): + super(ConstantInjector, self).__init__(opt, env) + self.constant_type = opt['constant_type'] + self.dim = opt['dim'] + self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. + + def forward(self, state): + bs = state[self.like].shape[0] + dev = state[self.like].device + if self.constant_type == 'zeroes': + out = torch.zeros((bs,) + tuple(self.dim), device=dev) + else: + raise NotImplementedError + return { self.opt['out']: out } diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index e2717d94..304e58e8 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -247,9 +247,10 @@ class PingPongLoss(ConfigurableLoss): def forward(self, _, state): fake = state[self.opt['fake']] l_total = 0 - for i in range((len(fake) - 1) // 2): - early = fake[i] - late = fake[-i] + img_count = fake.shape[1] + for i in range((img_count - 1) // 2): + early = fake[:, i] + late = fake[:, -i] l_total += self.criterion(early, late) if self.env['step'] % 50 == 0: @@ -262,6 +263,7 @@ class PingPongLoss(ConfigurableLoss): return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) - assert isinstance(imglist, list) - for i, img in enumerate(imglist): + cnt = imglist.shape[1] + for i in range(cnt): + img = imglist[:, i] torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))