Implement a few changes to support training BYOL networks
This commit is contained in:
parent
2437b33e74
commit
1bbcb96ee8
|
@ -35,6 +35,9 @@ class ByolDatasetWrapper(Dataset):
|
|||
super().__init__()
|
||||
self.wrapped_dataset = create_dataset(opt['dataset'])
|
||||
self.cropped_img_size = opt['crop_size']
|
||||
self.key1 = opt_get(opt, ['key1'], 'hq')
|
||||
self.key2 = opt_get(opt, ['key2'], 'lq')
|
||||
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
|
@ -49,7 +52,7 @@ class ByolDatasetWrapper(Dataset):
|
|||
|
||||
def __getitem__(self, item):
|
||||
item = self.wrapped_dataset[item]
|
||||
item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
|
||||
item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -22,6 +22,9 @@ class ImageFolderDataset:
|
|||
self.scale = opt['scale']
|
||||
self.paths = opt['paths']
|
||||
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
|
||||
self.fetch_alt_image = opt['fetch_alt_image'] # If specified, this dataset will attempt to find a second image
|
||||
# from the same video source. Search for 'fetch_alt_image' for more info.
|
||||
self.skip_lq = opt['skip_lq']
|
||||
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
|
||||
if not isinstance(self.paths, list):
|
||||
self.paths = [self.paths]
|
||||
|
@ -110,13 +113,56 @@ class ImageFolderDataset:
|
|||
dim = hq.shape[0]
|
||||
|
||||
hs = self.resize_hq([hq])
|
||||
ls = self.synthesize_lq(hs)
|
||||
if not self.skip_lq:
|
||||
ls = self.synthesize_lq(hs)
|
||||
|
||||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
|
||||
if not self.skip_lq:
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
|
||||
|
||||
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
||||
if not self.skip_lq:
|
||||
out_dict['lq'] = lq
|
||||
|
||||
if self.fetch_alt_image:
|
||||
# This works by assuming a specific filename structure as would produced by ffmpeg. ex:
|
||||
# 'Candied Walnutsxjktqhr_SYc.webm_00000478.jpg` and
|
||||
# 'Candied Walnutsxjktqhr_SYc.webm_00000479.jpg` and
|
||||
# 'Candied Walnutsxjktqhr_SYc.webm_00000480.jpg`
|
||||
# The basic format is `<anything>%08d.<extension>`. This logic parses off that 8 digit number. If it is
|
||||
# not found, the 'alt_image' returned is just the current image. If it is found, the algorithm searches for
|
||||
# an image one number higher. If it is found - it is returned in the 'alt_hq' and 'alt_lq' keys, else the
|
||||
# current image is put in those keys.
|
||||
|
||||
imname_parts = self.image_paths[item]
|
||||
while '.jpg.jpg' in imname_parts:
|
||||
imname_parts = imname_parts.replace(".jpg.jpg", ".jpg") # Hack workaround to my own bug.
|
||||
imname_parts = imname_parts.split('.')
|
||||
if len(imname_parts) >= 2 and len(imname_parts[-2]) > 8:
|
||||
try:
|
||||
imnumber = int(imname_parts[-2][-8:])
|
||||
# When we're dealing with images in the 1M range, it's straight up faster to attempt to just open
|
||||
# the file rather than searching the path list. Let the exception handler below do its work.
|
||||
next_img = self.image_paths[item].replace(str(imnumber), str(imnumber+1))
|
||||
alt_hq = util.read_img(None, next_img, rgb=True)
|
||||
alt_hq = self.resize_hq([alt_hq])
|
||||
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hq[0], (2, 0, 1)))).float()
|
||||
if not self.skip_lq:
|
||||
alt_lq = self.synthesize_lq(alt_hq)
|
||||
alt_lq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq[0], (2, 0, 1)))).float()
|
||||
except:
|
||||
alt_hq = hq
|
||||
if not self.skip_lq:
|
||||
alt_lq = lq
|
||||
else:
|
||||
alt_hq = hq
|
||||
if not self.skip_lq:
|
||||
alt_lq = lq
|
||||
out_dict['alt_hq'] = alt_hq
|
||||
if not self.skip_lq:
|
||||
out_dict['alt_lq'] = alt_lq
|
||||
|
||||
out_dict = {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
||||
if self.labeler:
|
||||
base_file = self.image_paths[item].replace(self.paths[0], "")
|
||||
while base_file.startswith("\\"):
|
||||
|
@ -131,19 +177,20 @@ class ImageFolderDataset:
|
|||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\'],
|
||||
'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
|
||||
'weights': [1],
|
||||
'target_size': 512,
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
'scale': 2,
|
||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||
'random_corruptions': ['noise-5', 'none'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'corrupt_before_downsize': True,
|
||||
'labeler': {
|
||||
'type': 'patch_labels',
|
||||
'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
|
||||
}
|
||||
'fetch_alt_image': True,
|
||||
#'labeler': {
|
||||
# 'type': 'patch_labels',
|
||||
# 'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
|
||||
#}
|
||||
}
|
||||
|
||||
ds = ImageFolderDataset(opt)
|
||||
|
@ -152,11 +199,11 @@ if __name__ == '__main__':
|
|||
for i in range(0, len(ds)):
|
||||
o = ds[random.randint(0, len(ds)-1)]
|
||||
hq = o['hq']
|
||||
masked = (o['labels_mask'] * .5 + .5) * hq
|
||||
#masked = (o['labels_mask'] * .5 + .5) * hq
|
||||
import torchvision
|
||||
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_hq.png" % (i,))
|
||||
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_masked.png" % (i,))
|
||||
if len(o['labels'].unique()) > 1:
|
||||
randlbl = np.random.choice(o['labels'].unique()[1:])
|
||||
moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
|
||||
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
|
||||
torchvision.utils.save_image(o['alt_hq'].unsqueeze(0), "debug/%i_hq_alt.png" % (i,))
|
||||
#if len(o['labels'].unique()) > 1:
|
||||
# randlbl = np.random.choice(o['labels'].unique()[1:])
|
||||
# moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
|
||||
# torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
|
|
@ -113,7 +113,10 @@ def read_img(env, path, size=None, rgb=False):
|
|||
"""read image by cv2 or from lmdb or from a buffer (in which case path=buffer)
|
||||
return: Numpy float32, HWC, BGR, [0,1]"""
|
||||
if env is None: # img
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
# Indirect open then process to support unicode files.
|
||||
stream = open(path, "rb")
|
||||
bytes = bytearray(stream.read())
|
||||
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
elif env is 'lmdb':
|
||||
img = _read_img_lmdb(env, path, size)
|
||||
elif env is 'buffer':
|
||||
|
|
|
@ -148,7 +148,7 @@ class NetWrapper(nn.Module):
|
|||
if self.structural_mlp:
|
||||
projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
|
||||
else:
|
||||
_, dim = hidden.shape
|
||||
_, dim = hidden.flatten(1,-1).shape
|
||||
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
|
|
|
@ -13,15 +13,16 @@ import torch
|
|||
def main():
|
||||
split_img = False
|
||||
opt = {}
|
||||
opt['n_thread'] = 10
|
||||
opt['n_thread'] = 8
|
||||
opt['compression_level'] = 90 # JPEG compression quality rating.
|
||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||
|
||||
opt['dest'] = 'file'
|
||||
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\other_ns']
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised'
|
||||
opt['imgsize'] = 512
|
||||
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised']
|
||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
|
||||
opt['imgsize'] = 256
|
||||
#opt['bottom_crop'] = 120
|
||||
|
||||
save_folder = opt['save_folder']
|
||||
if not osp.exists(save_folder):
|
||||
|
@ -36,6 +37,7 @@ class TiledDataset(data.Dataset):
|
|||
self.opt = opt
|
||||
input_folder = opt['input_folder']
|
||||
self.images = data_util.get_image_paths('img', input_folder)[0]
|
||||
print("Found %i images" % (len(self.images),))
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get(index)
|
||||
|
@ -43,7 +45,7 @@ class TiledDataset(data.Dataset):
|
|||
def get(self, index):
|
||||
path = self.images[index]
|
||||
basename = osp.basename(path)
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
img = data_util.read_img(None, path)
|
||||
|
||||
# Greyscale not supported.
|
||||
if img is None:
|
||||
|
@ -51,6 +53,12 @@ class TiledDataset(data.Dataset):
|
|||
return None
|
||||
if len(img.shape) == 2:
|
||||
return None
|
||||
|
||||
# Perform explicit crops first. These are generally used to get rid of watermarks so we dont even want to
|
||||
# consider these areas of the image.
|
||||
if 'bottom_crop' in self.opt.keys():
|
||||
img = img[:-self.opt['bottom_crop'], :, :]
|
||||
|
||||
h, w, c = img.shape
|
||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||
if min(h,w) < 512:
|
||||
|
@ -61,7 +69,14 @@ class TiledDataset(data.Dataset):
|
|||
# Crop the image so that only the center is left, since this is often the most salient part of the image.
|
||||
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
||||
|
||||
# I was having some issues with unicode filenames with cv2. Hence using PIL.
|
||||
# cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = (img * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img)
|
||||
img.save(osp.join(self.opt['save_folder'], basename + ".jpg"), "JPEG", quality=self.opt['compression_level'], optimize=True)
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_glean.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -125,6 +125,8 @@ def define_G(opt, opt_net, scale=None):
|
|||
from models.spinenet_arch import SpinenetWithLogits
|
||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'resnet52':
|
||||
netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
|
||||
elif which_model == 'glean':
|
||||
from models.glean.glean import GleanGenerator
|
||||
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user