Misc script fixes
This commit is contained in:
parent
9dc3c8f0ff
commit
a777c1e4f9
|
@ -859,7 +859,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
|
||||||
|
|
||||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||||
if self.env['step'] % self.gp_frequency == 0:
|
if self.env['step'] % self.gp_frequency == 0:
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||||
gp = gradient_penalty(real_input, real)
|
gp = gradient_penalty(real_input, real)
|
||||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||||
divergence_loss = divergence_loss + gp
|
divergence_loss = divergence_loss + gp
|
||||||
|
@ -880,11 +880,11 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
w_styles = state[self.w_styles]
|
w_styles = state[self.w_styles]
|
||||||
gen = state[self.gen]
|
gen = state[self.gen]
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths
|
from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
|
||||||
pl_lengths = calc_pl_lengths(w_styles, gen)
|
pl_lengths = calc_pl_lengths(w_styles, gen)
|
||||||
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
||||||
|
|
||||||
from models.archs.stylegan.stylegan2_lucidrains import is_empty
|
from models.stylegan.stylegan2_lucidrains import is_empty
|
||||||
if not is_empty(self.pl_mean):
|
if not is_empty(self.pl_mean):
|
||||||
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
||||||
if not torch.isnan(pl_loss):
|
if not torch.isnan(pl_loss):
|
||||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -446,9 +446,13 @@ class SaveImages(Injector):
|
||||||
self.target = opt['target']
|
self.target = opt['target']
|
||||||
self.thresh = opt['threshold']
|
self.thresh = opt['threshold']
|
||||||
self.index = 0
|
self.index = 0
|
||||||
|
self.rindex = 0
|
||||||
self.run_id = random.randint(0, 999999)
|
self.run_id = random.randint(0, 999999)
|
||||||
self.savedir = opt['savedir']
|
self.savedir = opt['savedir']
|
||||||
os.makedirs(self.savedir, exist_ok=True)
|
os.makedirs(self.savedir, exist_ok=True)
|
||||||
|
self.rejectdir = opt['negatives']
|
||||||
|
if self.rejectdir:
|
||||||
|
os.makedirs(self.rejectdir, exist_ok=True)
|
||||||
self.softmax = torch.nn.Softmax(dim=1)
|
self.softmax = torch.nn.Softmax(dim=1)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
|
@ -459,4 +463,7 @@ class SaveImages(Injector):
|
||||||
if logits[b][self.target] > self.thresh:
|
if logits[b][self.target] > self.thresh:
|
||||||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
||||||
self.index += 1
|
self.index += 1
|
||||||
|
elif self.rejectdir:
|
||||||
|
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
||||||
|
self.rindex += 1
|
||||||
return {}
|
return {}
|
Loading…
Reference in New Issue
Block a user