Misc script fixes
This commit is contained in:
parent
9dc3c8f0ff
commit
a777c1e4f9
|
@ -859,7 +859,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
|
|||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
from models.stylegan.stylegan2_lucidrains import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
@ -880,11 +880,11 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
|
|||
def forward(self, net, state):
|
||||
w_styles = state[self.w_styles]
|
||||
gen = state[self.gen]
|
||||
from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths
|
||||
from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
|
||||
pl_lengths = calc_pl_lengths(w_styles, gen)
|
||||
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
||||
|
||||
from models.archs.stylegan.stylegan2_lucidrains import is_empty
|
||||
from models.stylegan.stylegan2_lucidrains import is_empty
|
||||
if not is_empty(self.pl_mean):
|
||||
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
||||
if not torch.isnan(pl_loss):
|
||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_quality_detectors/train_resnet_blur.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_6bl_opt.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -446,9 +446,13 @@ class SaveImages(Injector):
|
|||
self.target = opt['target']
|
||||
self.thresh = opt['threshold']
|
||||
self.index = 0
|
||||
self.rindex = 0
|
||||
self.run_id = random.randint(0, 999999)
|
||||
self.savedir = opt['savedir']
|
||||
os.makedirs(self.savedir, exist_ok=True)
|
||||
self.rejectdir = opt['negatives']
|
||||
if self.rejectdir:
|
||||
os.makedirs(self.rejectdir, exist_ok=True)
|
||||
self.softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
def forward(self, state):
|
||||
|
@ -459,4 +463,7 @@ class SaveImages(Injector):
|
|||
if logits[b][self.target] > self.thresh:
|
||||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
||||
self.index += 1
|
||||
elif self.rejectdir:
|
||||
torchvision.utils.save_image(images[b], os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
||||
self.rindex += 1
|
||||
return {}
|
Loading…
Reference in New Issue
Block a user