Misc nonfunctional mods to datasets

This commit is contained in:
James Betker 2020-10-22 10:16:17 -06:00
parent 43c4f92123
commit f9dc472f63
2 changed files with 24 additions and 9 deletions

View File

@ -31,6 +31,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
if os.path.exists(cache_path): if os.path.exists(cache_path):
chunks = torch.load(cache_path) chunks = torch.load(cache_path)
else: else:
print("Building chunk cache, this can take some time for large datasets..")
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()] chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
# Prune out chunks that have no images # Prune out chunks that have no images
res = [] res = []

View File

@ -55,14 +55,14 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'paths': ['/content/fullvideo_256_tiled_test'], 'paths': ['F:\\4k6k\\datasets\\images\\fullvideo\\ge_fv_256_tiled'],
'weights': [1], 'weights': [1],
'target_size': 128, 'target_size': 128,
'force_multiple': 32, 'force_multiple': 32,
@ -77,14 +77,28 @@ if __name__ == '__main__':
ds = MultiFrameDataset(opt) ds = MultiFrameDataset(opt)
import os import os
os.makedirs("debug", exist_ok=True) os.makedirs("debug", exist_ok=True)
for i in [3]: bs = 0
batch = None
for i in range(len(ds)):
import random import random
o = ds[i] k = 'LQ'
k = 'GT' element = ds[random.randint(0,len(ds))]
v = o[k] base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)
if bs < 32:
if batch is None:
batch = o
else:
batch = torch.cat([batch, o], dim=0)
bs += 1
continue
if 'path' not in k and 'center' not in k: if 'path' not in k and 'center' not in k:
fr, f, h, w = v.shape b, fr, f, h, w = batch.shape
for j in range(fr): for j in range(fr):
import torchvision import torchvision
base=osp.basename(o["GT_path"]) base=osp.basename(base_file)
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base)) torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
bs = 0
batch = None