Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
7cbf4fa665
|
@ -36,6 +36,8 @@ def create_injector(opt_inject, env):
|
||||||
return MarginRemoval(opt_inject, env)
|
return MarginRemoval(opt_inject, env)
|
||||||
elif type == 'foreach':
|
elif type == 'foreach':
|
||||||
return ForEachInjector(opt_inject, env)
|
return ForEachInjector(opt_inject, env)
|
||||||
|
elif type == 'constant':
|
||||||
|
return ConstantInjector(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -220,7 +222,6 @@ class MarginRemoval(Injector):
|
||||||
input = state[self.input]
|
input = state[self.input]
|
||||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||||
|
|
||||||
|
|
||||||
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
|
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
|
||||||
class ForEachInjector(Injector):
|
class ForEachInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
|
@ -238,4 +239,21 @@ class ForEachInjector(Injector):
|
||||||
for i in range(inputs.shape[1]):
|
for i in range(inputs.shape[1]):
|
||||||
st['_in'] = inputs[:, i]
|
st['_in'] = inputs[:, i]
|
||||||
injs.append(self.injector(st)['_out'])
|
injs.append(self.injector(st)['_out'])
|
||||||
return {self.output: torch.stack(injs, dim=1)}
|
return {self.output: torch.stack(injs, dim=1)}
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ConstantInjector, self).__init__(opt, env)
|
||||||
|
self.constant_type = opt['constant_type']
|
||||||
|
self.dim = opt['dim']
|
||||||
|
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
bs = state[self.like].shape[0]
|
||||||
|
dev = state[self.like].device
|
||||||
|
if self.constant_type == 'zeroes':
|
||||||
|
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return { self.opt['out']: out }
|
||||||
|
|
|
@ -247,9 +247,10 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
fake = state[self.opt['fake']]
|
fake = state[self.opt['fake']]
|
||||||
l_total = 0
|
l_total = 0
|
||||||
for i in range((len(fake) - 1) // 2):
|
img_count = fake.shape[1]
|
||||||
early = fake[i]
|
for i in range((img_count - 1) // 2):
|
||||||
late = fake[-i]
|
early = fake[:, i]
|
||||||
|
late = fake[:, -i]
|
||||||
l_total += self.criterion(early, late)
|
l_total += self.criterion(early, late)
|
||||||
|
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
|
@ -262,6 +263,7 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
assert isinstance(imglist, list)
|
cnt = imglist.shape[1]
|
||||||
for i, img in enumerate(imglist):
|
for i in range(cnt):
|
||||||
|
img = imglist[:, i]
|
||||||
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user