Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
7cbf4fa665
|
@ -36,6 +36,8 @@ def create_injector(opt_inject, env):
|
|||
return MarginRemoval(opt_inject, env)
|
||||
elif type == 'foreach':
|
||||
return ForEachInjector(opt_inject, env)
|
||||
elif type == 'constant':
|
||||
return ConstantInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -220,7 +222,6 @@ class MarginRemoval(Injector):
|
|||
input = state[self.input]
|
||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||
|
||||
|
||||
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
|
||||
class ForEachInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
|
@ -238,4 +239,21 @@ class ForEachInjector(Injector):
|
|||
for i in range(inputs.shape[1]):
|
||||
st['_in'] = inputs[:, i]
|
||||
injs.append(self.injector(st)['_out'])
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
|
||||
|
||||
class ConstantInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ConstantInjector, self).__init__(opt, env)
|
||||
self.constant_type = opt['constant_type']
|
||||
self.dim = opt['dim']
|
||||
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||
|
||||
def forward(self, state):
|
||||
bs = state[self.like].shape[0]
|
||||
dev = state[self.like].device
|
||||
if self.constant_type == 'zeroes':
|
||||
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
|
|
|
@ -247,9 +247,10 @@ class PingPongLoss(ConfigurableLoss):
|
|||
def forward(self, _, state):
|
||||
fake = state[self.opt['fake']]
|
||||
l_total = 0
|
||||
for i in range((len(fake) - 1) // 2):
|
||||
early = fake[i]
|
||||
late = fake[-i]
|
||||
img_count = fake.shape[1]
|
||||
for i in range((img_count - 1) // 2):
|
||||
early = fake[:, i]
|
||||
late = fake[:, -i]
|
||||
l_total += self.criterion(early, late)
|
||||
|
||||
if self.env['step'] % 50 == 0:
|
||||
|
@ -262,6 +263,7 @@ class PingPongLoss(ConfigurableLoss):
|
|||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
assert isinstance(imglist, list)
|
||||
for i, img in enumerate(imglist):
|
||||
cnt = imglist.shape[1]
|
||||
for i in range(cnt):
|
||||
img = imglist[:, i]
|
||||
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
||||
|
|
Loading…
Reference in New Issue
Block a user