Misc script fixes

This commit is contained in:
James Betker 2020-12-29 20:25:09 -07:00
parent 9dc3c8f0ff
commit a777c1e4f9
3 changed files with 11 additions and 4 deletions

View File

@ -859,7 +859,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
# Apply gradient penalty. TODO: migrate this elsewhere. # Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0: if self.env['step'] % self.gp_frequency == 0:
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty from models.stylegan.stylegan2_lucidrains import gradient_penalty
gp = gradient_penalty(real_input, real) gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach())) self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp divergence_loss = divergence_loss + gp
@ -880,11 +880,11 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
def forward(self, net, state): def forward(self, net, state):
w_styles = state[self.w_styles] w_styles = state[self.w_styles]
gen = state[self.gen] gen = state[self.gen]
from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen) pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.archs.stylegan.stylegan2_lucidrains import is_empty from models.stylegan.stylegan2_lucidrains import is_empty
if not is_empty(self.pl_mean): if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss): if not torch.isnan(pl_loss):

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -446,9 +446,13 @@ class SaveImages(Injector):
self.target = opt['target'] self.target = opt['target']
self.thresh = opt['threshold'] self.thresh = opt['threshold']
self.index = 0 self.index = 0
self.rindex = 0
self.run_id = random.randint(0, 999999) self.run_id = random.randint(0, 999999)
self.savedir = opt['savedir'] self.savedir = opt['savedir']
os.makedirs(self.savedir, exist_ok=True) os.makedirs(self.savedir, exist_ok=True)
self.rejectdir = opt['negatives']
if self.rejectdir:
os.makedirs(self.rejectdir, exist_ok=True)
self.softmax = torch.nn.Softmax(dim=1) self.softmax = torch.nn.Softmax(dim=1)
def forward(self, state): def forward(self, state):
@ -459,4 +463,7 @@ class SaveImages(Injector):
if logits[b][self.target] > self.thresh: if logits[b][self.target] > self.thresh:
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg')) torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
self.index += 1 self.index += 1
elif self.rejectdir:
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
self.rindex += 1
return {} return {}