Add random-crop injector

This commit is contained in:
James Betker 2021-01-07 12:14:55 -07:00
parent 61a86a3c1e
commit 04961b91cf

View File

@@ -1,13 +1,12 @@
-import os
 import random
 import torch.nn
-import torchvision
+from kornia.augmentation import RandomResizedCrop
 from torch.cuda.amp import autocast
-from utils.weight_scheduler import get_scheduler_for_opt
-from trainer.losses import extract_params_from_state
 from trainer.inject import Injector
+from trainer.losses import extract_params_from_state
+from utils.weight_scheduler import get_scheduler_for_opt
 # Uses a generator to synthesize an image from [in] and injects the results into [out]
@@ -372,34 +371,17 @@ class MixAndLabelInjector(Injector):
         return {self.out_labels: labels, self.output: output}
-# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering
-# using ExtensibleTrainer.
-class SaveImages(Injector):
+# Randomly performs a uniform resize & crop from a base image.
+# Never resizes below input resolution or messes with the aspect ratio.
+class RandomCropInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
-        self.logits = opt['logits']
-        self.target = opt['target']
-        self.thresh = opt['threshold']
-        self.index = 0
-        self.rindex = 0
-        self.run_id = random.randint(0, 999999)
-        self.savedir = opt['savedir']
-        os.makedirs(self.savedir, exist_ok=True)
-        self.rejectdir = opt['negatives']
-        if self.rejectdir:
-            os.makedirs(self.rejectdir, exist_ok=True)
-        self.softmax = torch.nn.Softmax(dim=1)
+        dim_in = opt['dim_in']
+        dim_out = opt['dim_out']
+        scale = dim_out / dim_in
+        self.operator = RandomResizedCrop(size=(dim_out, dim_out), scale=(scale, 1), ratio=(1,1),
+                                          resample='NEAREST')
     def forward(self, state):
-        logits = self.softmax(state[self.logits])
-        images = state[self.input]
-        bs = images.shape[0]
-        for b in range(bs):
-            if logits[b][self.target] > self.thresh:
-                torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg'))
-                self.index += 1
-            elif self.rejectdir:
-                torchvision.utils.save_image(images[b],
-                                             os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg'))
-                self.rindex += 1
-        return {}
+        return {self.output: self.operator(self.input)}