Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-11 08:33:30 -06:00
commit 7cbf4fa665
2 changed files with 27 additions and 7 deletions

View File

@ -36,6 +36,8 @@ def create_injector(opt_inject, env):
return MarginRemoval(opt_inject, env)
elif type == 'foreach':
return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
else:
raise NotImplementedError
@ -220,7 +222,6 @@ class MarginRemoval(Injector):
input = state[self.input]
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
class ForEachInjector(Injector):
def __init__(self, opt, env):
@ -239,3 +240,20 @@ class ForEachInjector(Injector):
st['_in'] = inputs[:, i]
injs.append(self.injector(st)['_out'])
return {self.output: torch.stack(injs, dim=1)}
class ConstantInjector(Injector):
def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type']
self.dim = opt['dim']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
def forward(self, state):
bs = state[self.like].shape[0]
dev = state[self.like].device
if self.constant_type == 'zeroes':
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
else:
raise NotImplementedError
return { self.opt['out']: out }

View File

@ -247,9 +247,10 @@ class PingPongLoss(ConfigurableLoss):
def forward(self, _, state):
fake = state[self.opt['fake']]
l_total = 0
for i in range((len(fake) - 1) // 2):
early = fake[i]
late = fake[-i]
img_count = fake.shape[1]
for i in range((img_count - 1) // 2):
early = fake[:, i]
late = fake[:, -i]
l_total += self.criterion(early, late)
if self.env['step'] % 50 == 0:
@ -262,6 +263,7 @@ class PingPongLoss(ConfigurableLoss):
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list)
for i, img in enumerate(imglist):
cnt = imglist.shape[1]
for i in range(cnt):
img = imglist[:, i]
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))