diff --git a/codes/data/REDS_dataset.py b/codes/data/REDS_dataset.py deleted file mode 100644 index 36f69dcd..00000000 --- a/codes/data/REDS_dataset.py +++ /dev/null @@ -1,210 +0,0 @@ -''' -REDS dataset -support reading images from lmdb, image folder and memcached -''' -import os.path as osp -import random -import pickle -import logging -import numpy as np -import cv2 -import lmdb -import torch -import torch.utils.data as data -import data.util as util -try: - import mc # import memcached -except ImportError: - pass - -logger = logging.getLogger('base') - - -class REDSDataset(data.Dataset): - ''' - Reading the training REDS dataset - key example: 000_00000000 - GT: Ground-Truth; - LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames - support reading N LQ frames, N = 1, 3, 5, 7 - ''' - - def __init__(self, opt): - super(REDSDataset, self).__init__() - self.opt = opt - # temporal augmentation - self.interval_list = opt['interval_list'] - self.random_reverse = opt['random_reverse'] - logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( - ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) - - self.half_N_frames = opt['N_frames'] // 2 - self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] - self.data_type = self.opt['data_type'] - self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True # low resolution inputs - #### directly load image keys - if self.data_type == 'lmdb': - self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) - logger.info('Using lmdb meta info for cache keys.') - elif opt['cache_keys']: - logger.info('Using cache keys: {}'.format(opt['cache_keys'])) - self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] - else: - raise ValueError( - 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') - - # remove the REDS4 for testing - self.paths_GT = [ - v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020'] - ] - assert self.paths_GT, 'Error: GT path is empty.' - - if self.data_type == 'lmdb': - self.GT_env, self.LQ_env = None, None - elif self.data_type == 'mc': # memcached - self.mclient = None - elif self.data_type == 'img': - pass - else: - raise ValueError('Wrong data type: {}'.format(self.data_type)) - - def _init_lmdb(self): - # https://github.com/chainer/chainermn/issues/129 - self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, - meminit=False) - self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, - meminit=False) - - def _ensure_memcached(self): - if self.mclient is None: - # specify the config files - server_list_config_file = None - client_config_file = None - self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, - client_config_file) - - def _read_img_mc(self, path): - ''' Return BGR, HWC, [0, 255], uint8''' - value = mc.pyvector() - self.mclient.Get(path, value) - value_buf = mc.ConvertBuffer(value) - img_array = np.frombuffer(value_buf, np.uint8) - img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) - return img - - def _read_img_mc_BGR(self, path, name_a, name_b): - ''' Read BGR channels separately and then combine for 1M limits in cluster''' - img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png')) - img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png')) - img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png')) - img = cv2.merge((img_B, img_G, img_R)) - return img - - def __getitem__(self, index): - if self.data_type == 'mc': - self._ensure_memcached() - elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): - self._init_lmdb() - - scale = self.opt['scale'] - GT_size = self.opt['target_size'] - key = self.paths_GT[index] - name_a, name_b = key.split('_') - center_frame_idx = int(name_b) - - #### determine the neighbor frames - interval = random.choice(self.interval_list) - if self.opt['border_mode']: - direction = 1 # 1: forward; 0: backward - N_frames = self.opt['N_frames'] - if self.random_reverse and random.random() < 0.5: - direction = random.choice([0, 1]) - if center_frame_idx + interval * (N_frames - 1) > 99: - direction = 0 - elif center_frame_idx - interval * (N_frames - 1) < 0: - direction = 1 - # get the neighbor list - if direction == 1: - neighbor_list = list( - range(center_frame_idx, center_frame_idx + interval * N_frames, interval)) - else: - neighbor_list = list( - range(center_frame_idx, center_frame_idx - interval * N_frames, -interval)) - name_b = '{:08d}'.format(neighbor_list[0]) - else: - # ensure not exceeding the borders - while (center_frame_idx + self.half_N_frames * interval > - 99) or (center_frame_idx - self.half_N_frames * interval < 0): - center_frame_idx = random.randint(0, 99) - # get the neighbor list - neighbor_list = list( - range(center_frame_idx - self.half_N_frames * interval, - center_frame_idx + self.half_N_frames * interval + 1, interval)) - if self.random_reverse and random.random() < 0.5: - neighbor_list.reverse() - name_b = '{:08d}'.format(neighbor_list[self.half_N_frames]) - - assert len( - neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format( - len(neighbor_list)) - - #### get the GT image (as the center frame) - if self.data_type == 'mc': - img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b) - img_GT = img_GT.astype(np.float32) / 255. - elif self.data_type == 'lmdb': - img_GT = util.read_img(self.GT_env, key, (3, 720, 1280)) - else: - img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png')) - - #### get LQ images - LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280) - img_LQ_l = [] - for v in neighbor_list: - img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v)) - if self.data_type == 'mc': - if self.LR_input: - img_LQ = self._read_img_mc(img_LQ_path) - else: - img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v)) - img_LQ = img_LQ.astype(np.float32) / 255. - elif self.data_type == 'lmdb': - img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple) - else: - img_LQ = util.read_img(None, img_LQ_path) - img_LQ_l.append(img_LQ) - - if self.opt['phase'] == 'train': - C, H, W = LQ_size_tuple # LQ size - # randomly crop - if self.LR_input: - LQ_size = GT_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] - else: - rnd_h = random.randint(0, max(0, H - GT_size)) - rnd_w = random.randint(0, max(0, W - GT_size)) - img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] - img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] - - # augmentation - flip, rotate - img_LQ_l.append(img_GT) - rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) - img_LQ_l = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(img_LQ_l, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, - (0, 3, 1, 2)))).float() - return {'LQs': img_LQs, 'GT': img_GT, 'key': key} - - def __len__(self): - return len(self.paths_GT) diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index 5a1e5315..40f83302 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -1,28 +1,36 @@ import os.path as osp from data import util import torch +import numpy as np # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates. class ChunkWithReference: def __init__(self, opt, path): self.opt = opt - self.path = path + self.path = path.path self.ref = None # This is loaded on the fly. self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True - self.tiles = util.get_image_paths('img', path) + self.tiles, _ = util.get_image_paths('img', path) + self.centers = None def __getitem__(self, item): + # Load centers on the fly and always cache. + if self.centers is None: + self.centers = torch.load(osp.join(self.path, "centers.pt")) if self.cache_ref: if self.ref is None: - self.ref = util.read_img(None, osp.join(self.path, "ref.jpg")) + self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True) self.centers = torch.load(osp.join(self.path, "centers.pt")) ref = self.ref - centers = self.centers else: - self.ref = util.read_img(None, osp.join(self.path, "ref.jpg")) - self.centers = torch.load(osp.join(self.path, "centers.pt")) + self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True) + tile = util.read_img(None, self.tiles[item], rgb=True) + tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0]) + center, tile_width = self.centers[tile_id] + mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) + mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 - return self.tiles[item], ref, centers[item], path + return tile, ref, center, mask, self.tiles[item] def __len__(self): - return len(self.tiles) \ No newline at end of file + return len(self.tiles) diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index ddacfddb..196114f1 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -1,72 +1,92 @@ import random - +import cv2 +import numpy as np +from data.util import read_img +from PIL import Image +from io import BytesIO # Performs image corruption on a list of images from a configurable set of corruption # options. class ImageCorruptor: def __init__(self, opt): self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2 - self.corruptions_enabled = opt['corruptions'] + self.fixed_corruptions = opt['fixed_corruptions'] + self.random_corruptions = opt['random_corruptions'] def corrupt_images(self, imgs): - augmentations = random.choice(self.corruptions_enabled, k=self.num_corrupts) + augmentations = random.choices(self.random_corruptions, k=self.num_corrupts) # Source of entropy, which should be used across all images. - rand_int = random.randint(1, 999999) + rand_int_f = random.randint(1, 999999) + rand_int_a = random.randint(1, 999999) corrupted_imgs = [] for img in imgs: + for aug in self.fixed_corruptions: + img = self.apply_corruption(img, aug, rand_int_f) for aug in augmentations: - if 'color_quantization' in aug: - # Color quantization - quant_div = 2 ** random.randint(1, 4) - augmentation_tensor[AUG_TENSOR_COLOR_QUANT] = float(quant_div) / 5.0 + img = self.apply_corruption(img, aug, rand_int_a) + corrupted_imgs.append(img) - pass - elif 'gaussian_blur' in aug: - # Gaussian Blur - kernel = random.randint(1, 3) * 3 - image = cv2.GaussianBlur(image, (kernel, kernel), 3) - augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9 - elif 'median_blur' in aug: - # Median Blur - kernel = random.randint(1, 3) * 3 - image = cv2.medianBlur(image, kernel) - augmentation_tensor[AUG_TENSOR_BLUR] = float(kernel) / 9 - elif 'motion_blur' in aug: - # Motion blur - intensity = random.randrange(1, 9) - image = self.motion_blur(image, intensity, random.randint(0, 360)) - augmentation_tensor[AUG_TENSOR_BLUR] = intensity / 9 - elif 'smooth_blur' in aug: - # Smooth blur - kernel = random.randint(1, 3) * 3 - image = cv2.blur(image, ksize=kernel) - augmentation_tensor[AUG_TENSOR_BLUR] = kernel / 9 - elif 'block_noise' in aug: - # Block noise - noise_intensity = random.randint(3, 10) - image += np.random.randn() - pass - elif 'lq_resampling' in aug: - # Bicubic LR->HR - pass - elif 'color_shift' in aug: - # Color shift - pass - elif 'interlacing' in aug: - # Interlacing distortion - pass - elif 'chromatic_aberration' in aug: - # Chromatic aberration - pass - elif 'noise' in aug: - # Noise - pass - elif 'jpeg' in aug: - # JPEG compression - pass - elif 'saturation' in aug: - # Lightening / saturation - pass + return corrupted_imgs - return corrupted_imgs \ No newline at end of file + def apply_corruption(self, img, aug, rand_int): + if 'color_quantization' in aug: + # Color quantization + quant_div = 2 ** ((rand_int % 3) + 2) + img = img * 255 + img = (img // quant_div) * quant_div + img = img / 255 + elif 'gaussian_blur' in aug: + # Gaussian Blur + kernel = 2 * (rand_int % 3) + 1 + img = cv2.GaussianBlur(img, (kernel, kernel), 3) + elif 'motion_blur' in aug: + # Motion blur + intensity = 2 * (rand_int % 3) + 1 + angle = (rand_int // 3) % 360 + k = np.zeros((intensity, intensity), dtype=np.float32) + k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32) + k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0), + (intensity, intensity)) + k = k * (1.0 / np.sum(k)) + img = cv2.filter2D(img, -1, k) + elif 'smooth_blur' in aug: + # Smooth blur + kernel = 2 * (rand_int % 3) + 1 + img = cv2.blur(img, ksize=(kernel, kernel)) + elif 'block_noise' in aug: + # Large distortion blocks in part of an img, such as is used to mask out a face. + pass + elif 'lq_resampling' in aug: + # Bicubic LR->HR + pass + elif 'color_shift' in aug: + # Color shift + pass + elif 'interlacing' in aug: + # Interlacing distortion + pass + elif 'chromatic_aberration' in aug: + # Chromatic aberration + pass + elif 'noise' in aug: + # Block noise + noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4 + img += np.random.randn() * noise_intensity + elif 'jpeg' in aug: + # JPEG compression + qf = (rand_int % 20 + 10) # Between 10-30 + # cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead. + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + buffer = BytesIO() + img.save(buffer, "JPEG", quality=qf, optimice=True) + buffer.seek(0) + jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8") + img = read_img("buffer", jpeg_img_bytes, rgb=True) + elif 'saturation' in aug: + # Lightening / saturation + saturation = float(rand_int % 10) * .03 + img = np.clip(img + saturation, a_max=1, a_min=0) + + return img diff --git a/codes/data/lmdb_dataset_with_ref.py b/codes/data/lmdb_dataset_with_ref.py deleted file mode 100644 index 1bf94796..00000000 --- a/codes/data/lmdb_dataset_with_ref.py +++ /dev/null @@ -1,231 +0,0 @@ -import random -import numpy as np -import cv2 -import torch -import torch.utils.data as data -import data.util as util -from PIL import Image, ImageOps -from io import BytesIO -import torchvision.transforms.functional as F -import lmdb -import pyarrow - - -# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to -# where those tiles came from. -class LmdbDatasetWithRef(data.Dataset): - - def __init__(self, opt): - super(LmdbDatasetWithRef, self).__init__() - self.opt = opt - self.db = lmdb.open(self.opt['lmdb_path'], subdir=True, readonly=True, lock=False, readahead=False, meminit=False) - self.data_type = 'img' - self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 - with self.db.begin(write=False) as txn: - self.keys = pyarrow.deserialize(txn.get(b'__keys__')) - self.len = pyarrow.deserialize(txn.get(b'__len__'))\ - - def motion_blur(self, image, size, angle): - k = np.zeros((size, size), dtype=np.float32) - k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32) - k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size)) - k = k * (1.0 / np.sum(k)) - return cv2.filter2D(image, -1, k) - - def resize_point(self, point, orig_dim, new_dim): - oh, ow = orig_dim - nh, nw = new_dim - dh, dw = float(nh) / float(oh), float(nw) / float(ow) - point[0] = int(dh * float(point[0])) - point[1] = int(dw * float(point[1])) - return point - - def augment_tile(self, img_GT, img_LQ, strength=1): - scale = self.opt['scale'] - GT_size = self.opt['target_size'] - - H, W, _ = img_GT.shape - assert H >= GT_size and W >= GT_size - - LQ_size = GT_size // scale - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - - if self.opt['use_blurring']: - # Pick randomly between gaussian, motion, or no blur. - blur_det = random.randint(0, 100) - blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] - blur_magnitude = max(1, int(blur_magnitude*strength)) - if blur_det < 40: - blur_sig = int(random.randrange(0, int(blur_magnitude))) - img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig) - elif blur_det < 70: - img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360)) - - return img_GT, img_LQ - - # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it. - def pil_augment(self, img_LQ, strength=1): - img_LQ = (img_LQ * 255).astype(np.uint8) - img_LQ = Image.fromarray(img_LQ) - if self.opt['use_compression_artifacts'] and random.random() > .25: - sub_lo = 90 * strength - sub_hi = 30 * strength - qf = random.randrange(100 - sub_lo, 100 - sub_hi) - corruption_buffer = BytesIO() - img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) - corruption_buffer.seek(0) - img_LQ = Image.open(corruption_buffer) - - if 'grayscale' in self.opt.keys() and self.opt['grayscale']: - img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') - - return img_LQ - - def __getitem__(self, index): - scale = self.opt['scale'] - - # get the hq image and the ref image - key = self.keys[index] - ref_key = key[:key.index('_')] - with self.db.begin(write=False) as txn: - bytes_ref = txn.get(ref_key.encode()) - bytes_tile = txn.get(key.encode()) - unpacked_ref = pyarrow.deserialize(bytes_ref) - unpacked_tile = pyarrow.deserialize(bytes_tile) - gt_fullsize_ref = unpacked_ref[0] - img_GT, gt_center = unpacked_tile - - # TODO: synthesize gt_mask. - gt_mask = np.ones(img_GT.shape[:2]) - orig_gt_dim = gt_fullsize_ref.shape[:2] - - # Synthesize LQ by downsampling. - if self.opt['phase'] == 'train': - GT_size = self.opt['target_size'] - random_scale = random.choice(self.random_scale_list) - if len(img_GT.shape) == 2: - print("ERRAR:") - print(img_GT.shape) - print(full_path) - H_s, W_s, _ = img_GT.shape - - def _mod(n, random_scale, scale, thres): - rlt = int(n * random_scale) - rlt = (rlt // scale) * scale - return thres if rlt < thres else rlt - - H_s = _mod(H_s, random_scale, scale, GT_size) - W_s = _mod(W_s, random_scale, scale, GT_size) - img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) - if img_GT.ndim == 2: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) - - H, W, _ = img_GT.shape - - # using matlab imresize - img_LQ = util.imresize_np(img_GT, 1 / scale, True) - lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True) - if img_LQ.ndim == 2: - img_LQ = np.expand_dims(img_LQ, axis=2) - lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2]) - orig_lq_dim = lq_fullsize_ref.shape[:2] - - # Enforce force_resize constraints via clipping. - h, w, _ = img_LQ.shape - if h % self.force_multiple != 0 or w % self.force_multiple != 0: - h, w = (h - h % self.force_multiple), (w - w % self.force_multiple) - img_LQ = img_LQ[:h, :w, :] - lq_fullsize_ref = lq_fullsize_ref[:h, :w, :] - h *= scale - w *= scale - img_GT = img_GT[:h, :w] - gt_fullsize_ref = gt_fullsize_ref[:h, :w, :] - - if self.opt['phase'] == 'train': - img_GT, img_LQ = self.augment_tile(img_GT, img_LQ) - gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2) - - # Scale masks. - lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) - gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) - - # Scale center coords - lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2]) - gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2]) - - # BGR to RGB, HWC to CHW, numpy to tensor - if img_GT.shape[2] == 3: - img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB) - img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB) - lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB) - gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB) - - # LQ needs to go to a PIL image to perform the compression-artifact transformation. - if self.opt['phase'] == 'train': - img_LQ = self.pil_augment(img_LQ) - lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2) - - img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float() - img_LQ = F.to_tensor(img_LQ) - lq_fullsize_ref = F.to_tensor(lq_fullsize_ref) - lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0) - gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0) - - if 'lq_noise' in self.opt.keys(): - lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255 - img_LQ += lq_noise - lq_fullsize_ref += lq_noise - - # Apply the masks to the full images. - gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0) - lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0) - - d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref, - 'lq_center': lq_center, 'gt_center': gt_center, - 'LQ_path': key, 'GT_path': key} - return d - - def __len__(self): - return self.len - -if __name__ == '__main__': - opt = { - 'name': 'amalgam', - 'lmdb_path': 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref', - 'use_flip': True, - 'use_compression_artifacts': True, - 'use_blurring': True, - 'use_rot': True, - 'lq_noise': 5, - 'target_size': 128, - 'min_tile_size': 256, - 'scale': 2, - 'phase': 'train' - } - ''' - opt = { - 'name': 'amalgam', - 'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imagesets-lmdb-ref'], - 'dataroot_GT_weights': [1], - 'force_multiple': 32, - 'scale': 2, - 'phase': 'test' - } - ''' - - ds = LmdbDatasetWithRef(opt) - import os - os.makedirs("debug", exist_ok=True) - for i in range(300, len(ds)): - print(i) - o = ds[i] - for k, v in o.items(): - if 'path' not in k: - #if 'full' in k: - #masked = v[:3, :, :] * v[3] - #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) - #v = v[:3, :, :] - import torchvision - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) \ No newline at end of file diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py index d10c05a6..6354d3a5 100644 --- a/codes/data/single_image_dataset.py +++ b/codes/data/single_image_dataset.py @@ -5,6 +5,8 @@ import os from bisect import bisect_left import cv2 import torch +import numpy as np +import torchvision.transforms.functional as F # Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been @@ -23,8 +25,17 @@ class SingleImageDataset(data.Dataset): self.weights = [1] else: self.weights = opt['weights'] + + # See if there is a cached directory listing and use that rather than re-scanning everything. This will greatly + # reduce startup costs. + self.chunks = [] for path, weight in zip(self.paths, self.weights): - chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] + cache_path = os.path.join(path, 'cache.pth') + if os.path.exists(cache_path): + chunks = torch.load(cache_path) + else: + chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] + torch.save(chunks, cache_path) for w in range(weight): self.chunks.extend(chunks) @@ -36,25 +47,18 @@ class SingleImageDataset(data.Dataset): start += len(c) self.len = start - def binary_search(elem, sorted_list): - # https://docs.python.org/3/library/bisect.html - 'Locate the leftmost value exactly equal to x' - i = bisect_left(sorted_list, elem) - if i != len(sorted_list) and sorted_list[i] == elem: - return i - return -1 def resize_point(self, point, orig_dim, new_dim): oh, ow = orig_dim nh, nw = new_dim dh, dw = float(nh) / float(oh), float(nw) / float(ow) - point[0] = int(dh * float(point[0])) - point[1] = int(dw * float(point[1])) + point = int(dh * float(point[0])), int(dw * float(point[1])) return point def __getitem__(self, item): - chunk_ind = self.binary_search(item, self.starting_indices) - hq, hq_ref, hq_center, path = self.chunks[item-self.starting_indices[chunk_ind]] + chunk_ind = bisect_left(self.starting_indices, item) + chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 + hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]] # Enforce size constraints h, w, _ = hq.shape @@ -63,6 +67,7 @@ class SingleImageDataset(data.Dataset): target_size = (self.target_hq_size, self.target_hq_size) hq = cv2.resize(hq, target_size, interpolation=cv2.INTER_LINEAR) hq_ref = cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_LINEAR) + hq_mask = cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_LINEAR) hq_center = self.resize_point(hq_center, (h, w), target_size) h, w = self.target_hq_size, self.target_hq_size hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image. @@ -71,6 +76,7 @@ class SingleImageDataset(data.Dataset): hq_center = self.resize_point(hq_center, hq.shape[:1], (h, w)) hq = hq[:h, :w, :] hq_ref = hq_ref[:h, :w, :] + hq_mask = hq_mask[:h, :w, :] # Synthesize the LQ image if self.for_eval: @@ -78,20 +84,70 @@ class SingleImageDataset(data.Dataset): else: lq = cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR) lq_ref = cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR) - lq_center = self.resize_point(hq_center, (h, w), lq.shape[:1]) + lq_mask = cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_LINEAR) + lq_center = self.resize_point(hq_center, (h, w), lq.shape[:2]) # Corrupt the LQ image - lq = self.corruptor.corrupt_images([lq]) + lq = self.corruptor.corrupt_images([lq])[0] # Convert to torch tensor hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hq, (2, 0, 1)))).float() hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hq_ref, (2, 0, 1)))).float() - lq = F.to_tensor(lq) - lq_ref = F.to_tensor(lq_ref) + hq_mask = torch.from_numpy(np.ascontiguousarray(hq_mask)).unsqueeze(dim=0) + hq_ref = torch.cat([hq_ref, hq_mask], dim=0) + lq = torch.from_numpy(np.ascontiguousarray(np.transpose(lq, (2, 0, 1)))).float() + lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lq_ref, (2, 0, 1)))).float() + lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0) + lq_ref = torch.cat([lq_ref, lq_mask], dim=0) return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, 'lq_center': lq_center, 'gt_center': hq_center, 'LQ_path': path, 'GT_path': path} def __len__(self): - return self.len \ No newline at end of file + return self.len + + + self.corruptor = ImageCorruptor(opt) + self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None + self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1 + self.for_eval = opt['eval'] if 'eval' in opt.keys() else False + self.scale = opt['scale'] if not self.for_eval else 1 + self.paths = opt['paths'] + if not isinstance(self.paths, list): + self.paths = [self.paths] + self.weights = [1] + else: + self.weights = opt['weights'] + for path, weight in zip(self.paths, self.weights): + chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] + for w in range(weight): + self.chunks.extend(chunks) + +if __name__ == '__main__': + opt = { + 'name': 'amalgam', + 'paths': ['F:\\4k6k\\datasets\\images\\flickr\\testbed'], + 'weights': [1], + 'target_size': 128, + 'force_multiple': 32, + 'scale': 2, + 'eval': False, + 'fixed_corruptions': ['jpeg'], + 'random_corruptions': ['color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', 'saturation'], + 'num_corrupts_per_image': 1 + } + + ds = SingleImageDataset(opt) + import os + os.makedirs("debug", exist_ok=True) + for i in range(0, len(ds)): + o = ds[i] + for k, v in o.items(): + if 'path' not in k and 'center' not in k: + #if 'full' in k: + #masked = v[:3, :, :] * v[3] + #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) + #v = v[:3, :, :] + import torchvision + torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) \ No newline at end of file diff --git a/codes/data/util.py b/codes/data/util.py index 7d42c0d1..f3ad9005 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -87,13 +87,18 @@ def _read_img_lmdb(env, key, size): return img -def read_img(env, path, size=None): - """read image by cv2 or from lmdb +def read_img(env, path, size=None, rgb=False): + """read image by cv2 or from lmdb or from a buffer (in which case path=buffer) return: Numpy float32, HWC, BGR, [0,1]""" if env is None: # img img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - else: + elif env is 'lmdb': img = _read_img_lmdb(env, path, size) + elif env is 'buffer': + img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED) + else: + raise NotImplementedError("Unsupported env: %s" % (env,)) + if img is None: print("Image error: %s" % (path,)) img = img.astype(np.float32) / 255. @@ -102,6 +107,9 @@ def read_img(env, path, size=None): # some images have 4 channels if img.shape[2] > 3: img = img[:, :, :3] + + if rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img diff --git a/codes/data_scripts/extract_subimages_with_ref.py b/codes/data_scripts/extract_subimages_with_ref.py index 50c3c4b8..05864c98 100644 --- a/codes/data_scripts/extract_subimages_with_ref.py +++ b/codes/data_scripts/extract_subimages_with_ref.py @@ -16,17 +16,17 @@ def main(): mode = 'single' # single (one input folder) | pair (extract corresponding GT and LR pairs) split_img = False opt = {} - opt['n_thread'] = 12 + opt['n_thread'] = 0 opt['compression_level'] = 90 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. if mode == 'single': opt['dest'] = 'file' - opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_segments' - opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_with_refs' + opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half' + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\new_tiles' opt['crop_sz'] = [256, 512, 1024] # the size of each sub-image - opt['step'] = 256 # step of the sliding crop window + opt['step'] = [256, 512, 1024] # step of the sliding crop window opt['thres_sz'] = 128 # size threshold opt['resize_final_img'] = [1, .5, .25] opt['only_resize'] = False @@ -158,7 +158,7 @@ class FileWriter: # Writes the given reference image to the db and returns its ID. def write_reference_image(self, ref_img, path): - ref_img, _ = ref_img # Encoded with a center point, which is irrelevant for the reference image. + ref_img, _, _ = ref_img # Encoded with a center point, which is irrelevant for the reference image. img_name = osp.basename(path).replace(".jpg", "").replace(".png", "") self.ref_center_points[img_name] = {} self.save_image(img_name, "ref.jpg", ref_img) @@ -170,14 +170,18 @@ class FileWriter: def write_tile_image(self, ref_id, tile_image): id = self.get_next_unique_id() ref_name = self.ref_ids_to_names[ref_id] - img, center = tile_image - self.ref_center_points[ref_name][id] = center + img, center, tile_sz = tile_image + self.ref_center_points[ref_name][id] = center, tile_sz self.save_image(ref_name, "%08i.jpg" % (id,), img) return id - def close(self): + def flush(self): for ref_name, cps in self.ref_center_points.items(): torch.save(cps, osp.join(self.folder, ref_name, "centers.pt")) + self.ref_center_points = {} + + def close(self): + self.flush() class TiledDataset(data.Dataset): def __init__(self, opt, split_mode=False): @@ -192,8 +196,9 @@ class TiledDataset(data.Dataset): else: return self.get(index, False, False) - def get_for_scale(self, img, split_mode, left_img, crop_sz, resize_factor): - step = self.opt['step'] + def get_for_scale(self, img, split_mode, left_image, crop_sz, step, resize_factor, ref_resize_factor): + assert not left_image # Split image not yet supported, False is the default value. + thres_sz = self.opt['thres_sz'] h, w, c = img.shape @@ -215,15 +220,14 @@ class TiledDataset(data.Dataset): for y in w_space: index += 1 crop_img = img[x:x + crop_sz, y:y + crop_sz, :] - center_point = (x + crop_sz // 2, y + crop_sz // 2) + # Center point needs to be resized by ref_resize_factor - since it is relative to the reference image. + center_point = (int((x + crop_sz // 2) // ref_resize_factor), int((y + crop_sz // 2) // ref_resize_factor)) crop_img = np.ascontiguousarray(crop_img) if 'resize_final_img' in self.opt.keys(): - # Resize too. - center_point = (int(center_point[0] * resize_factor), int(center_point[1] * resize_factor)) crop_img = cv2.resize(crop_img, dsize, interpolation=cv2.INTER_AREA) success, buffer = cv2.imencode(".jpg", crop_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) assert success - results.append((buffer, center_point)) + results.append((buffer, center_point, int(crop_sz // ref_resize_factor))) return results def get(self, index, split_mode, left_img): @@ -255,15 +259,16 @@ class TiledDataset(data.Dataset): tile_dim = int(self.opt['crop_sz'][0] * self.opt['resize_final_img'][0]) dsize = (tile_dim, tile_dim) + ref_resize_factor = h / tile_dim # Reference image should always be first entry in results. ref_img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA) success, ref_buffer = cv2.imencode(".jpg", ref_img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) assert success - results = [(ref_buffer, (-1,-1))] + results = [(ref_buffer, (-1,-1), (-1,-1))] - for crop_sz, resize_factor in zip(self.opt['crop_sz'], self.opt['resize_final_img']): - results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, resize_factor)) + for crop_sz, resize_factor, step in zip(self.opt['crop_sz'], self.opt['resize_final_img'], self.opt['step']): + results.extend(self.get_for_scale(img, split_mode, left_img, crop_sz, step, resize_factor, ref_resize_factor)) return results, path def __len__(self): @@ -286,6 +291,7 @@ def extract_single(opt, writer, split_img=False): ref_id = writer.write_reference_image(imgs[0], path) for tile in imgs[1:]: writer.write_tile_image(ref_id, tile) + writer.flush() writer.close() diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 8ca5ba43..10aab1f1 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -238,6 +238,128 @@ class SPSRNet(nn.Module): return x_out_branch, x_out, x_grad +class SwitchedSpsr(nn.Module): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + super(SwitchedSpsr, self).__init__() + n_upscale = int(math.log(upscale, 2)) + + # switch options + transformation_filters = nf + switch_filters = nf + switch_reductions = 3 + switch_processing_layers = 2 + self.transformation_counts = xforms + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, + switch_processing_layers, self.transformation_counts, use_exp2=True) + pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=3, + weight_init_factor=.1) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) + self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + + # Grad branch + self.get_g_nopadding = ImageGradientNoPadding() + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) + mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions, + switch_processing_layers, self.transformation_counts // 2, use_exp2=True) + self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts // 2, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + # Upsampling + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) + self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + # Conv used to output grad branch shortcut. + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + + # Conjoin branch. + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=4, + weight_init_factor=.1) + pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1) + self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=True) + self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) + self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) + self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] + self.attentions = None + self.init_temperature = init_temperature + self.final_temperature_step = 10000 + + def forward(self, x): + x_grad = self.get_g_nopadding(x) + x = self.model_fea_conv(x) + + x1, a1 = self.sw1(x, True) + x2, a2 = self.sw2(x1, True) + x_fea = self.feature_lr_conv(x2) + x_fea = self.feature_hr_conv2(x_fea) + + x_b_fea = self.b_fea_conv(x_grad) + x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True) + x_grad = self.grad_lr_conv(x_grad) + x_grad = self.grad_hr_conv(x_grad) + x_out_branch = self.upsample_grad(x_grad) + x_out_branch = self.grad_branch_output_conv(x_out_branch) + + x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1) + x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True) + x_out = self.final_lr_conv(x__branch_pretrain_cat) + x_out = self.upsample(x_out) + x_out = self.final_hr_conv1(x_out) + x_out = self.final_hr_conv2(x_out) + + self.attentions = [a1, a2, a3, a4] + + return x_out_branch, x_out, x_grad + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 200 == 0: + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val + + class RefJoiner(nn.Module): def __init__(self, nf): super(RefJoiner, self).__init__() diff --git a/codes/models/networks.py b/codes/models/networks.py index 8f56ddd9..60aae360 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -49,6 +49,10 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == 'spsr_net_improved': netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) + elif which_model == "spsr_switched": + xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 + netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "spsr_switched_with_ref2" or which_model == "spsr3": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.SwitchedSpsrWithRef2(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],