forked from mrq/DL-Art-School

Directly use dataset keys

This has been a long time coming. Cleans up the messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data.

This commit is contained in:
parent 8a83b1c716
commit 11155aead4
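In practice, the rename means every dataset's __getitem__ now returns the same keys the trainer state consumes, with no translation layer. A minimal sketch of the new contract (tensor shapes and values are illustrative, not taken from the repo):

import torch

# Before this commit, datasets returned MMSR-style keys:
#   {'LQ': lq_tensor, 'GT': hq_tensor, ...}
# After it, the trainer reads the dataset dict directly, so datasets emit:
sample = {
    'lq': torch.randn(3, 32, 32),    # low-quality input patch
    'hq': torch.randn(3, 128, 128),  # high-quality target patch
    'GT_path': 'images/0001.png',    # path keys keep their old names
}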
@@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
 1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
    filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
    dataset are grouped together with a temporal dimension for working with video data.
 1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
    on those images like the above.
 1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
    from the source image and resizing them to the target size recursively until the native resolution is hit. Each
    recursive step decreases the crop size by a factor of 2.
 1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images
    and reformats them in a way that the DLAS trainer understands.
 1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
-   generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
-   and test so I keep it around.
+   generally stopped using this for performance reasons and it should be considered deprecated.

 ## Information about the "chunked" format
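All of the datasets listed above share one dict contract, and after this commit the image tensors live under 'lq'/'hq'. A self-contained toy stand-in (not a DLAS class) showing that contract being consumed:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical stand-in implementing the dict contract the DLAS datasets share."""
    def __len__(self):
        return 8
    def __getitem__(self, i):
        hq = torch.rand(3, 64, 64)
        # Derive a 4x-downsampled "low quality" counterpart.
        lq = torch.nn.functional.interpolate(hq.unsqueeze(0), scale_factor=0.25).squeeze(0)
        return {'lq': lq, 'hq': hq, 'GT_path': '%i.png' % i}

loader = DataLoader(ToyDataset(), batch_size=4)
batch = next(iter(loader))
print(batch['lq'].shape, batch['hq'].shape)  # [4, 3, 16, 16] and [4, 3, 64, 64]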
@@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
             gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
             lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)

-        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+        d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
              'lq_center': lq_center, 'gt_center': gt_center,
              'LQ_path': LQ_path, 'GT_path': full_path}
         return d
@@ -95,7 +95,7 @@ class ImageFolderDataset:
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
         lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()

-        return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+        return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}


 if __name__ == '__main__':
@@ -118,7 +118,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
@@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)

-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -83,7 +83,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
         element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
@@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
         patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
         patches_lq = torch.stack(patches_lq, dim=0)

-        d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
+        d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
         return d

     def __len__(self):
@@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=1)

-        return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -68,7 +68,7 @@ if __name__ == '__main__':
     batch = None
     for i in range(len(ds)):
         import random
-        k = 'LQ'
+        k = 'lq'
         element = ds[random.randint(0,len(ds))]
         base_file = osp.basename(element["GT_path"])
         o = element[k].unsqueeze(0)
@@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
             lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
             lq_ref = torch.cat([lq_ref, lq_mask], dim=0)

-        return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+        return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
                 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
                 'LQ_path': path, 'GT_path': path}
@@ -62,7 +62,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'LQ'
+        k = 'lq'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
@@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
         path = self.paths[index]
         img = Image.open(path)
         img = self.transform(img)
-        return {'LQ': img, 'GT': img, 'GT_path': str(path)}
+        return {'lq': img, 'hq': img, 'GT_path': str(path)}
@@ -24,7 +24,7 @@ class TorchDataset(Dataset):

     def __getitem__(self, item):
         underlying_item = self.dataset[item][0]
-        return {'LQ': underlying_item, 'GT': underlying_item,
+        return {'lq': underlying_item, 'hq': underlying_item,
                 'LQ_path': str(item), 'GT_path': str(item)}

     def __len__(self):
@@ -4,7 +4,6 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel

 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
+                # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+                from apex.parallel import DistributedDataParallel
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
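The inline comment above captures the reason for moving the import: gradient checkpointing replays parts of the forward graph during backward, which interacts badly with DDP implementations that allreduce each gradient bucket as soon as it becomes ready; Apex's delay_allreduce=True waits until backward has fully finished. A runnable toy showing the checkpoint replay mechanism itself (plain PyTorch, no Apex or distributed setup required):

import torch
from torch.utils.checkpoint import checkpoint

# checkpoint() discards intermediate activations and re-runs this module's
# forward during backward(). A per-bucket allreduce could fire for the same
# parameter during that replay; deferring the reduction until backward
# completes (what delay_allreduce=True does) avoids the conflict.
layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(layer, x)
y.sum().backward()
print(layer.weight.grad.shape)  # gradients exist exactly once after backward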
@@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()

-        self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
-        if need_GT:
-            self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
-            input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-            self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
-        else:
-            self.hq = self.lq
-            self.ref = self.lq
-
-        self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+        self.dstate = {}
         for k, v in data.items():
-            if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+            if isinstance(v, torch.Tensor):
                 self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]

     def optimize_parameters(self, step):
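This is the simplification the commit message refers to: with datasets emitting 'lq'/'hq' directly, feed_data no longer special-cases LQ/GT/ref, and simply chunks every tensor in the batch dict for gradient accumulation and moves it to the device under its own key. A rough standalone sketch of the new behavior (the function name and sample shapes are illustrative, not from the repo):

import torch

def feed_data_sketch(data: dict, device: torch.device, batch_factor: int) -> dict:
    # Split every tensor along the batch dim so a step can be replayed in
    # smaller micro-batches; non-tensor entries (e.g. path strings) are skipped.
    dstate = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return dstate

# Illustrative batch in the post-rename key format.
batch = {'lq': torch.randn(8, 3, 32, 32), 'hq': torch.randn(8, 3, 128, 128),
         'GT_path': ['a.png'] * 8}
state = feed_data_sketch(batch, torch.device('cpu'), batch_factor=2)
print([t.shape for t in state['lq']])  # two chunks of batch size 4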
@@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel):

     def get_current_visuals(self, need_GT=True):
         # Conforms to an archaic format from MMSR.
-        return {'LQ': self.eval_state['lq'][0].float().cpu(),
-                'GT': self.eval_state['hq'][0].float().cpu(),
+        return {'lq': self.eval_state['lq'][0].float().cpu(),
+                'hq': self.eval_state['hq'][0].float().cpu(),
                 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}

     def print_network(self):
@@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
         print("Evaluating FlowGaussianNll..")
         for batch in tqdm(self.dataloader):
             dev = self.env['device']
-            z, _, _ = self.model(gt=batch['GT'].to(dev),
-                                 lr=batch['LQ'].to(dev),
+            z, _, _ = self.model(gt=batch['hq'].to(dev),
+                                 lr=batch['lq'].to(dev),
                                  epses=[],
                                  reverse=False,
                                  add_gt_noise=False)
@@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
         counter = 0
         for batch in self.sampler:
             noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
-            batch_hq = [e['GT'] for e in batch]
+            batch_hq = [e['hq'] for e in batch]
             batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
             resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
             embedding = embedding_generator(resized_batch)
@@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
         self.log_dict = OrderedDict()

     def feed_data(self, data, need_GT=True):
-        self.var_L = data['LQ'].to(self.device)  # LQ
+        self.var_L = data['lq'].to(self.device)  # LQ
         if need_GT:
-            self.real_H = data['GT'].to(self.device)  # GT
+            self.real_H = data['hq'].to(self.device)  # GT

     def optimize_parameters(self, step):
         self.optimizer_G.zero_grad()
@@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
             img_LQ = lq_template
             ref = ref_template

-        return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+        return {'lq': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }

     def __len__(self):
@@ -159,8 +159,8 @@ if __name__ == "__main__":
         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True

         if recurrent_mode and first_frame:
-            b, c, h, w = data['LQ'].shape
-            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
+            b, c, h, w = data['lq'].shape
+            recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
             # Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
             if 'recurrent_hr_generator' in opt.keys():
                 recurrent_gen = model.env['generators']['generator']
@@ -40,8 +40,8 @@ if __name__ == '__main__':
             break
         sampled += 1

-        im = rgb2ycbcr(train_data['GT'].double())
-        im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
+        im = rgb2ycbcr(train_data['hq'].double())
+        im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
                                         size=im.shape[2:],
                                         mode="bicubic", align_corners=False))
         patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
@@ -16,7 +16,7 @@ import utils.util as util  # noqa: E402

 def main():
     dataset = 'DIV2K_demo'  # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo | test
-    mode = 'GT'  # used for vimeo90k and REDS datasets
+    mode = 'hq'  # used for vimeo90k and REDS datasets
     # vimeo90k: GT | LR | flow
     # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
     #       train_sharp_flowx4
@@ -159,7 +159,7 @@ def vimeo90k(mode):
     read_all_imgs = False  # whether real all images to memory with multiprocessing
     # Set False for use limited memory
     BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
-    if mode == 'GT':
+    if mode == 'hq':
         img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
         lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
         txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@@ -204,7 +204,7 @@ def vimeo90k(mode):
             keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
     all_img_list = sorted(all_img_list)
     keys = sorted(keys)
-    if mode == 'GT':  # only read the 4th frame for the GT mode
+    if mode == 'hq':  # only read the 4th frame for the GT mode
         print('Only keep the 4th frame.')
         all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
         keys = [v for v in keys if v.endswith('_4')]
@@ -255,9 +255,9 @@ def vimeo90k(mode):

     #### create meta information
     meta_info = {}
-    if mode == 'GT':
+    if mode == 'hq':
         meta_info['name'] = 'Vimeo90K_train_GT'
-    elif mode == 'LR':
+    elif mode == 'lq':
         meta_info['name'] = 'Vimeo90K_train_LR'
     elif mode == 'flow':
         meta_info['name'] = 'Vimeo90K_train_flowx4'
@@ -163,7 +163,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -180,10 +180,10 @@ if __name__ == "__main__":
     gen = model.networks['generator']
     gen.eval()

-    mode = "temperature"  # temperature | restore | latent_transfer | feed_through
+    mode = "restore"  # temperature | restore | latent_transfer | feed_through
     #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
-    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
-    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
+    imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
+    #imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
     scale = 2
     resample_factor = 1  # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
     temperature = 1
@@ -224,6 +224,11 @@ if __name__ == "__main__":
         t = image_2_tensor(img_file).to(model.env['device'])
         if resample_factor != 1:
             t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
+        # Ensure the input image is a factor of 16.
+        _, _, h, w = t.shape
+        h = 16 * (h // 16)
+        w = 16 * (w // 16)
+        t = t[:, :, :h, :w]
         resample_img = t

         # Fetch the latent metrics & latents for each image we are resampling.
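The crop added above keeps both spatial dimensions divisible by 16, presumably so the input survives the flow network's repeated 2x downsampling with latent grids that line up exactly. A quick worked check with a hypothetical 1080p input:

# 16 does not divide 1080, so a 1080p frame loses its bottom 8 rows.
h, w = 1080, 1920
h, w = 16 * (h // 16), 16 * (w // 16)
print(h, w)  # -> 1072 1920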
@@ -255,6 +260,6 @@ if __name__ == "__main__":
             for j in range(len(lats)):
                 path = os.path.join(output_path, "%i_%i" % (im_it, j))
                 os.makedirs(path, exist_ok=True)
-                torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
+                torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
                 create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
                                            path, [torch.zeros_like(l) for l in lats[j]], lats[j])
@@ -85,8 +85,8 @@ def main():
         if dataset == 'REDS' or dataset == 'Vimeo90K':
             LQs = data['LQs']
         else:
-            LQ = data['LQ']
-        GT = data['GT']
+            LQ = data['lq']
+        GT = data['hq']

         if dataset == 'REDS' or dataset == 'Vimeo90K':
             for j in range(LQs.size(1)):
@@ -68,6 +68,6 @@ if __name__ == "__main__":
             #    removed += 1
             imname = osp.basename(data['GT_path'][i])
             if results[i]-dataset_mean > 1:
-                torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+                torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))

     print("Removed %i/%i images" % (removed, len(test_set)))
@@ -66,7 +66,7 @@ if __name__ == "__main__":
         model.test()
         gen = model.eval_state['gen'][0].to(model.env['device'])
         feagen = netF(gen)
-        feareal = netF(data['GT'].to(model.env['device']))
+        feareal = netF(data['hq'].to(model.env['device']))
         losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
         means.append(torch.mean(losses).item())
         #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@@ -76,6 +76,6 @@ if __name__ == "__main__":
                 removed += 1
             #imname = osp.basename(data['GT_path'][i])
             #if losses[i] < 25000:
-            #    torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+            #    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))

     print("Removed %i/%i images" % (removed, len(test_set)))
@@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
         save_img_path = osp.join(output_dir, img_name + '.png')

         if need_GT:
-            fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
+            fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
             psnr_sr = util.tensor2img(visuals[i])
-            psnr_gt = util.tensor2img(data['GT'][i])
+            psnr_gt = util.tensor2img(data['hq'][i])
             psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)

         util.save_img(sr_img, save_img_path)
@@ -231,13 +231,13 @@ class Trainer:
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                     # calculate PSNR
                     if self.val_compute_psnr:
-                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                         sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                         avg_psnr += util.calculate_psnr(sr_img, gt_img)

                     # calculate fea loss
                     if self.val_compute_fea:
-                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])

                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging

+import torchvision
+from tqdm import tqdm

 import torch
@@ -231,18 +233,18 @@ class Trainer:
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                     # calculate PSNR
                     if self.val_compute_psnr:
-                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                         sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                         avg_psnr += util.calculate_psnr(sr_img, gt_img)

                     # calculate fea loss
                     if self.val_compute_fea:
-                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])

                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
                     save_img_path = os.path.join(img_dir, img_base_name)
-                    util.save_img(sr_img, save_img_path)
+                    torchvision.utils.save_image(visuals['rlt'], save_img_path)

                 avg_psnr = avg_psnr / idx
                 avg_fea_loss = avg_fea_loss / idx
@@ -291,7 +293,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()