forked from mrq/DL-Art-School
48 lines
1.7 KiB
Python
48 lines
1.7 KiB
Python
|
import random
|
||
|
|
||
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
from kornia import augmentation as augs
|
||
|
from kornia import filters
|
||
|
import torch.nn as nn
|
||
|
|
||
|
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
|
||
|
# inputs. The augmented 'hq' is output as 'aug1' and the augmented 'lq' as 'aug2'.
|
||
|
from data import create_dataset
|
||
|
|
||
|
|
||
|
class RandomApply(nn.Module):
    """Module wrapper that runs `fn` on its input only some of the time.

    Each forward pass draws one value from `random.random()`. When that value is
    greater than `p` the input is handed back untouched; otherwise `fn(x)` is
    returned.
    """

    def __init__(self, fn, p):
        """Store the wrapped callable `fn` and the threshold `p`."""
        super().__init__()
        self.p = p
        self.fn = fn

    def forward(self, x):
        """Return `x` unchanged or `self.fn(x)`, decided by a single random draw."""
        # Exactly one draw per call; the comparison direction matters for edge values of p.
        skip = random.random() > self.p
        return x if skip else self.fn(x)
|
||
|
|
||
|
|
||
|
class ByolDatasetWrapper(Dataset):
    """Dataset decorator that adds two randomly augmented views to every sample.

    The dataset described by opt['dataset'] is built via `create_dataset` and
    wrapped. Each fetched sample gains an 'aug1' entry (augmented 'hq') and an
    'aug2' entry (augmented 'lq').
    """

    def __init__(self, opt):
        """Build the wrapped dataset and the augmentation pipeline.

        Keys read from `opt`: 'dataset' (passed to `create_dataset`),
        'crop_size' (edge length used for the square random-resized crop) and
        'normalize' (truthy to append a Normalize step).
        """
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']

        crop_shape = (self.cropped_img_size, self.cropped_img_size)
        pipeline = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop(crop_shape),
        ]
        if opt['normalize']:
            # The paper calls for normalization; enable it to match the paper exactly.
            # NOTE(review): these look like the usual ImageNet channel statistics - confirm.
            channel_mean = torch.tensor([0.485, 0.456, 0.406])
            channel_std = torch.tensor([0.229, 0.224, 0.225])
            pipeline.append(augs.Normalize(mean=channel_mean, std=channel_std))
        self.aug = nn.Sequential(*pipeline)

    def __getitem__(self, item):
        """Fetch sample `item` from the wrapped dataset and attach 'aug1'/'aug2'.

        The sample object returned by the wrapped dataset is updated in place
        and returned.
        """
        sample = self.wrapped_dataset[item]
        # 'hq' is augmented before 'lq'. squeeze(dim=0) drops a leading size-1
        # dimension if present - presumably a batch dim added by the augmentations.
        view_hq = self.aug(sample['hq']).squeeze(dim=0)
        view_lq = self.aug(sample['lq']).squeeze(dim=0)
        sample.update({'aug1': view_hq, 'aug2': view_lq})
        return sample

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.wrapped_dataset)
|