diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index bdfdb272..1030d8c7 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -1,13 +1,12 @@ -import os import random import torch.nn -import torchvision +from kornia.augmentation import RandomResizedCrop from torch.cuda.amp import autocast -from utils.weight_scheduler import get_scheduler_for_opt -from trainer.losses import extract_params_from_state from trainer.inject import Injector +from trainer.losses import extract_params_from_state +from utils.weight_scheduler import get_scheduler_for_opt # Uses a generator to synthesize an image from [in] and injects the results into [out] @@ -372,34 +371,17 @@ class MixAndLabelInjector(Injector): return {self.out_labels: labels, self.output: output} -# Doesn't inject. Rather saves images that meet a specified criteria. Useful for performing classification filtering -# using ExtensibleTrainer. -class SaveImages(Injector): +# Randomly performs a uniform resize & crop from a base image. +# Never resizes below input resolution or messes with the aspect ratio. +class RandomCropInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) - self.logits = opt['logits'] - self.target = opt['target'] - self.thresh = opt['threshold'] - self.index = 0 - self.rindex = 0 - self.run_id = random.randint(0, 999999) - self.savedir = opt['savedir'] - os.makedirs(self.savedir, exist_ok=True) - self.rejectdir = opt['negatives'] - if self.rejectdir: - os.makedirs(self.rejectdir, exist_ok=True) - self.softmax = torch.nn.Softmax(dim=1) + dim_in = opt['dim_in'] + dim_out = opt['dim_out'] + scale = dim_out / dim_in + self.operator = RandomResizedCrop(size=(dim_out, dim_out), scale=(scale, 1), ratio=(1,1), + resample='NEAREST') def forward(self, state): - logits = self.softmax(state[self.logits]) - images = state[self.input] - bs = images.shape[0] - for b in range(bs): - if logits[b][self.target] > self.thresh: - torchvision.utils.save_image(images[b], os.path.join(self.savedir, f'{self.run_id}_{self.index}.jpg')) - self.index += 1 - elif self.rejectdir: - torchvision.utils.save_image(images[b], - os.path.join(self.rejectdir, f'{self.run_id}_{self.rindex}.jpg')) - self.rindex += 1 - return {} \ No newline at end of file + return {self.output: self.operator(self.input)} +