forked from mrq/DL-Art-School
Add random-crop injector
This commit is contained in:
parent
61a86a3c1e
commit
04961b91cf
|
@ -1,13 +1,12 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch.nn
|
||||
import torchvision
|
||||
from kornia.augmentation import RandomResizedCrop
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
from trainer.losses import extract_params_from_state
|
||||
from trainer.inject import Injector
|
||||
from trainer.losses import extract_params_from_state
|
||||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
|
||||
|
||||
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||
|
@ -372,34 +371,17 @@ class MixAndLabelInjector(Injector):
|
|||
return {self.out_labels: labels, self.output: output}
|
||||
|
||||
|
||||
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
|
||||
# using ExtensibleTrainer.
|
||||
class SaveImages(Injector):
|
||||
# Randomly performs a uniform resize & crop from a base image.
|
||||
# Never resizes below input resolution or messes with the aspect ratio.
|
||||
class RandomCropInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.logits = opt['logits']
|
||||
self.target = opt['target']
|
||||
self.thresh = opt['threshold']
|
||||
self.index = 0
|
||||
self.rindex = 0
|
||||
self.run_id = random.randint(0, 999999)
|
||||
self.savedir = opt['savedir']
|
||||
os.makedirs(self.savedir, exist_ok=True)
|
||||
self.rejectdir = opt['negatives']
|
||||
if self.rejectdir:
|
||||
os.makedirs(self.rejectdir, exist_ok=True)
|
||||
self.softmax = torch.nn.Softmax(dim=1)
|
||||
dim_in = opt['dim_in']
|
||||
dim_out = opt['dim_out']
|
||||
scale = dim_out / dim_in
|
||||
self.operator = RandomResizedCrop(size=(dim_out, dim_out), scale=(scale, 1), ratio=(1,1),
|
||||
resample='NEAREST')
|
||||
|
||||
def forward(self, state):
|
||||
logits = self.softmax(state[self.logits])
|
||||
images = state[self.input]
|
||||
bs = images.shape[0]
|
||||
for b in range(bs):
|
||||
if logits[b][self.target] > self.thresh:
|
||||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
||||
self.index += 1
|
||||
elif self.rejectdir:
|
||||
torchvision.utils.save_image(images[b],
|
||||
os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
||||
self.rindex += 1
|
||||
return {}
|
||||
return {self.output: self.operator(self.input)}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user