diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index d58221e8..fa4040a5 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -859,7 +859,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): # Apply gradient penalty. TODO: migrate this elsewhere. if self.env['step'] % self.gp_frequency == 0: - from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty + from models.stylegan.stylegan2_lucidrains import gradient_penalty gp = gradient_penalty(real_input, real) self.metrics.append(("gradient_penalty", gp.clone().detach())) divergence_loss = divergence_loss + gp @@ -880,11 +880,11 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss): def forward(self, net, state): w_styles = state[self.w_styles] gen = state[self.gen] - from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths + from models.stylegan.stylegan2_lucidrains import calc_pl_lengths pl_lengths = calc_pl_lengths(w_styles, gen) avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) - from models.archs.stylegan.stylegan2_lucidrains import is_empty + from models.stylegan.stylegan2_lucidrains import is_empty if not is_empty(self.pl_mean): pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() if not torch.isnan(pl_loss): diff --git a/codes/train.py b/codes/train.py index f3ccab0a..5ad1271f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -293,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/injectors.py b/codes/trainer/injectors.py index 36956bb7..98c2b43f 100644 --- a/codes/trainer/injectors.py +++ b/codes/trainer/injectors.py @@ -446,9 +446,13 @@ class SaveImages(Injector): self.target = opt['target'] self.thresh = opt['threshold'] self.index = 0 + self.rindex = 0 self.run_id = random.randint(0, 999999) self.savedir = opt['savedir'] os.makedirs(self.savedir, exist_ok=True) + self.rejectdir = opt['negatives'] + if self.rejectdir: + os.makedirs(self.rejectdir, exist_ok=True) self.softmax = torch.nn.Softmax(dim=1) def forward(self, state): @@ -459,4 +463,7 @@ class SaveImages(Injector): if logits[b][self.target] > self.thresh: torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg')) self.index += 1 + elif self.rejectdir: + torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg')) + self.rindex += 1 return {} \ No newline at end of file