Constant injector and teco fixes

This commit is contained in:
James Betker 2020-10-11 08:20:07 -06:00
parent e785029936
commit a9c2e97391
2 changed files with 27 additions and 7 deletions

View File

@ -36,6 +36,8 @@ def create_injector(opt_inject, env):
return MarginRemoval(opt_inject, env) return MarginRemoval(opt_inject, env)
elif type == 'foreach': elif type == 'foreach':
return ForEachInjector(opt_inject, env) return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
else: else:
raise NotImplementedError raise NotImplementedError
@ -220,7 +222,6 @@ class MarginRemoval(Injector):
input = state[self.input] input = state[self.input]
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
# Produces an injection which is composed of applying a single injector multiple times across a single dimension. # Produces an injection which is composed of applying a single injector multiple times across a single dimension.
class ForEachInjector(Injector): class ForEachInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
@ -238,4 +239,21 @@ class ForEachInjector(Injector):
for i in range(inputs.shape[1]): for i in range(inputs.shape[1]):
st['_in'] = inputs[:, i] st['_in'] = inputs[:, i]
injs.append(self.injector(st)['_out']) injs.append(self.injector(st)['_out'])
return {self.output: torch.stack(injs, dim=1)} return {self.output: torch.stack(injs, dim=1)}
class ConstantInjector(Injector):
def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type']
self.dim = opt['dim']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
def forward(self, state):
bs = state[self.like].shape[0]
dev = state[self.like].device
if self.constant_type == 'zeroes':
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
else:
raise NotImplementedError
return { self.opt['out']: out }

View File

@ -247,9 +247,10 @@ class PingPongLoss(ConfigurableLoss):
def forward(self, _, state): def forward(self, _, state):
fake = state[self.opt['fake']] fake = state[self.opt['fake']]
l_total = 0 l_total = 0
for i in range((len(fake) - 1) // 2): img_count = fake.shape[1]
early = fake[i] for i in range((img_count - 1) // 2):
late = fake[-i] early = fake[:, i]
late = fake[:, -i]
l_total += self.criterion(early, late) l_total += self.criterion(early, late)
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
@ -262,6 +263,7 @@ class PingPongLoss(ConfigurableLoss):
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list) cnt = imglist.shape[1]
for i, img in enumerate(imglist): for i in range(cnt):
img = imglist[:, i]
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))