Fix multi-frame dataset OBO error

This commit is contained in:
James Betker 2020-10-08 12:21:04 -06:00
parent fba29d7dcc
commit b36ba0460c
2 changed files with 7 additions and 5 deletions

View File

@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
for c in chunks: for c in chunks:
c.reload(opt) c.reload(opt)
else: else:
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
# Prune out chunks that have no images # Prune out chunks that have no images
res = [] res = []
for c in chunks: for c in chunks:

View File

@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
frames_needed -= 1 frames_needed -= 1
search_idx -= 1 search_idx -= 1
else: else:
search_idx += 1
break break
# Now build num_frames starting from search_idx. # Now build num_frames starting from search_idx.
@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'], 'paths': ['/content/fullvideo_256_tiled_test'],
'weights': [1], 'weights': [1],
'target_size': 128, 'target_size': 128,
'force_multiple': 32, 'force_multiple': 32,
@ -77,13 +78,14 @@ if __name__ == '__main__':
ds = MultiFrameDataset(opt) ds = MultiFrameDataset(opt)
import os import os
os.makedirs("debug", exist_ok=True) os.makedirs("debug", exist_ok=True)
for i in range(100000, len(ds)): for i in [3]:
import random import random
o = ds[random.randint(0, 1000000)] o = ds[i]
k = 'GT' k = 'GT'
v = o[k] v = o[k]
if 'path' not in k and 'center' not in k: if 'path' not in k and 'center' not in k:
fr, f, h, w = v.shape fr, f, h, w = v.shape
for j in range(fr): for j in range(fr):
import torchvision import torchvision
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) base=osp.basename(o["GT_path"])
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))