forked from mrq/DL-Art-School
Misc nonfunctional mods to datasets
This commit is contained in:
parent
43c4f92123
commit
f9dc472f63
|
@ -31,6 +31,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
||||||
if os.path.exists(cache_path):
|
if os.path.exists(cache_path):
|
||||||
chunks = torch.load(cache_path)
|
chunks = torch.load(cache_path)
|
||||||
else:
|
else:
|
||||||
|
print("Building chunk cache, this can take some time for large datasets..")
|
||||||
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
|
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
|
||||||
# Prune out chunks that have no images
|
# Prune out chunks that have no images
|
||||||
res = []
|
res = []
|
||||||
|
|
|
@ -55,14 +55,14 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
|
||||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
||||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
||||||
|
|
||||||
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
opt = {
|
opt = {
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['/content/fullvideo_256_tiled_test'],
|
'paths': ['F:\\4k6k\\datasets\\images\\fullvideo\\ge_fv_256_tiled'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 128,
|
'target_size': 128,
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
|
@ -77,14 +77,28 @@ if __name__ == '__main__':
|
||||||
ds = MultiFrameDataset(opt)
|
ds = MultiFrameDataset(opt)
|
||||||
import os
|
import os
|
||||||
os.makedirs("debug", exist_ok=True)
|
os.makedirs("debug", exist_ok=True)
|
||||||
for i in [3]:
|
bs = 0
|
||||||
|
batch = None
|
||||||
|
for i in range(len(ds)):
|
||||||
import random
|
import random
|
||||||
o = ds[i]
|
k = 'LQ'
|
||||||
k = 'GT'
|
element = ds[random.randint(0,len(ds))]
|
||||||
v = o[k]
|
base_file = osp.basename(element["GT_path"])
|
||||||
|
o = element[k].unsqueeze(0)
|
||||||
|
if bs < 32:
|
||||||
|
if batch is None:
|
||||||
|
batch = o
|
||||||
|
else:
|
||||||
|
batch = torch.cat([batch, o], dim=0)
|
||||||
|
bs += 1
|
||||||
|
continue
|
||||||
|
|
||||||
if 'path' not in k and 'center' not in k:
|
if 'path' not in k and 'center' not in k:
|
||||||
fr, f, h, w = v.shape
|
b, fr, f, h, w = batch.shape
|
||||||
for j in range(fr):
|
for j in range(fr):
|
||||||
import torchvision
|
import torchvision
|
||||||
base=osp.basename(o["GT_path"])
|
base=osp.basename(base_file)
|
||||||
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||||
|
|
||||||
|
bs = 0
|
||||||
|
batch = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user