Add random-crop injector
This commit is contained in:
parent
61a86a3c1e
commit
04961b91cf
|
@ -1,13 +1,12 @@
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import torchvision
|
from kornia.augmentation import RandomResizedCrop
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from utils.weight_scheduler import get_scheduler_for_opt
|
|
||||||
from trainer.losses import extract_params_from_state
|
|
||||||
from trainer.inject import Injector
|
from trainer.inject import Injector
|
||||||
|
from trainer.losses import extract_params_from_state
|
||||||
|
from utils.weight_scheduler import get_scheduler_for_opt
|
||||||
|
|
||||||
|
|
||||||
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
||||||
|
@ -372,34 +371,17 @@ class MixAndLabelInjector(Injector):
|
||||||
return {self.out_labels: labels, self.output: output}
|
return {self.out_labels: labels, self.output: output}
|
||||||
|
|
||||||
|
|
||||||
# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
|
# Randomly performs a uniform resize & crop from a base image.
|
||||||
# using ExtensibleTrainer.
|
# Never resizes below input resolution or messes with the aspect ratio.
|
||||||
class SaveImages(Injector):
|
class RandomCropInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
self.logits = opt['logits']
|
dim_in = opt['dim_in']
|
||||||
self.target = opt['target']
|
dim_out = opt['dim_out']
|
||||||
self.thresh = opt['threshold']
|
scale = dim_out / dim_in
|
||||||
self.index = 0
|
self.operator = RandomResizedCrop(size=(dim_out, dim_out), scale=(scale, 1), ratio=(1,1),
|
||||||
self.rindex = 0
|
resample='NEAREST')
|
||||||
self.run_id = random.randint(0, 999999)
|
|
||||||
self.savedir = opt['savedir']
|
|
||||||
os.makedirs(self.savedir, exist_ok=True)
|
|
||||||
self.rejectdir = opt['negatives']
|
|
||||||
if self.rejectdir:
|
|
||||||
os.makedirs(self.rejectdir, exist_ok=True)
|
|
||||||
self.softmax = torch.nn.Softmax(dim=1)
|
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
logits = self.softmax(state[self.logits])
|
return {self.output: self.operator(self.input)}
|
||||||
images = state[self.input]
|
|
||||||
bs = images.shape[0]
|
|
||||||
for b in range(bs):
|
|
||||||
if logits[b][self.target] > self.thresh:
|
|
||||||
torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
|
|
||||||
self.index += 1
|
|
||||||
elif self.rejectdir:
|
|
||||||
torchvision.utils.save_image(images[b],
|
|
||||||
os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
|
|
||||||
self.rindex += 1
|
|
||||||
return {}
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user