From 543d459b4efc3b410c197fdc7a1d026052d3765b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 8 Feb 2021 08:10:24 -0700
Subject: [PATCH] extract_temporal_squares script

For extracting related patches across a video
---
 codes/scripts/extract_temporal_squares.py | 201 ++++++++++++++++++++++
 1 file changed, 201 insertions(+)

diff --git a/codes/scripts/extract_temporal_squares.py b/codes/scripts/extract_temporal_squares.py
index e69de29b..bb74f7ed 100644
--- a/codes/scripts/extract_temporal_squares.py
+++ b/codes/scripts/extract_temporal_squares.py
@@ -0,0 +1,201 @@
+"""Extracts short runs of related frames from videos and computes a consolidated optical flow across each run."""
+import os
+import os.path as osp
+import shutil
+import subprocess
+from time import sleep
+
+import munch
+import numpy as np
+import cv2
+import torchvision
+from PIL import Image
+import data.util as data_util  # noqa: E402
+import torch.utils.data as data
+from tqdm import tqdm
+import torch
+import random
+
+from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from models.optical_flow.PWCNet import pwc_dc_net
+
+
+def main():
+    opt = {}
+    opt['n_thread'] = 0
+    opt['compression_level'] = 95  # JPEG compression quality rating.
+    opt['dest'] = 'file'
+    opt['input_folder'] = 'D:\\dlas\\codes\\scripts\\test'
+    opt['save_folder'] = 'D:\\dlas\\codes\\scripts\\test_out'
+    opt['imgsize'] = 256
+    opt['bottom_crop'] = .1
+    opt['keep_folder'] = False
+
+    save_folder = opt['save_folder']
+    if not osp.exists(save_folder):
+        os.makedirs(save_folder)
+        print('mkdir [{:s}] ...'.format(save_folder))
+
+    go(opt)
+
+
+def is_video(filename):
+    return any(filename.endswith(extension) for extension in ['.mp4', '.MP4', '.avi', '.AVI', '.mkv', '.MKV', '.wmv', '.WMV'])
+
+
+def get_videos_in_path(path):
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    videos = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_video(fname):
+                videos.append(os.path.join(dirpath, fname))
+    return videos
+
+
+def get_time_for_secs(secs):
+    # Converts a number of seconds into an ffmpeg-style HH:MM:SS.mmm timestamp.
+    mins = int(secs / 60)
+    hours = int(mins / 60)
+    secs = secs - (mins * 60)  # `mins` still includes the hours at this point, so no further subtraction is needed.
+    mins = mins % 60
+    return '%02d:%02d:%06.3f' % (hours, mins, secs)
+
+
+class VideoClipDataset(data.Dataset):
+    def __init__(self, opt):
+        self.opt = opt
+        input_folder = opt['input_folder']
+        self.videos = get_videos_in_path(input_folder)
+        print("Found %i videos" % (len(self.videos),))
+
+    def __getitem__(self, index):
+        return self.get(index)
+
+    def extract_n_frames(self, video_file, dest, time_seconds, n):
+        # Frame names are zero-padded so that lexical sorting preserves temporal order.
+        ffmpeg_args = ['ffmpeg', '-y', '-ss', get_time_for_secs(time_seconds), '-i', video_file, '-vframes', str(n), f'{dest}/%02d.jpg']
+        process = subprocess.Popen(ffmpeg_args, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
+        process.wait()
+
+    def get_video_length(self, video_file):
+        result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of",
+                                 "default=noprint_wrappers=1", video_file],
+                                stdout=subprocess.PIPE,
+                                stderr=subprocess.STDOUT)
+        return float(result.stdout.decode('utf-8').strip().replace("duration=", ""))
+
+    def get_image_tensor(self, path):
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+        # File-access exceptions happen regularly here, probably because the ffmpeg
+        # subprocess has not fully released the file yet, so retry the delete a few times.
+        for tries in range(5):
+            try:
+                os.remove(path)
+                break
+            except OSError:
+                if tries >= 4:
+                    raise
+                else:
+                    sleep(.1)
+
+        assert img is not None
+        assert len(img.shape) > 2
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
+        # Crop off excess so image dimensions are a multiple of 64.
+        h, w, _ = img.shape
+        img = img[:(h // 64) * 64, :(w // 64) * 64, :]
+        return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()
+
+    def get(self, index):
+        path = self.videos[index]
+        out_folder = self.opt['save_folder']
+
+        vid_len = int(self.get_video_length(path))
+        start = 2
+        img_runs = []
+        while start < vid_len:
+            frames_out = os.path.join(out_folder, f'{index}_{start}')
+            os.makedirs(frames_out, exist_ok=False)
+            n = random.randint(5, 30)
+            self.extract_n_frames(path, frames_out, start, n)
+            frames = data_util.get_image_paths('img', frames_out)[0]
+            assert len(frames) == n
+            img_runs.append(([self.get_image_tensor(frame) for frame in frames], frames_out))
+            start += random.randint(2, 5)
+
+        return img_runs
+
+    def __len__(self):
+        return len(self.videos)
+
+
+def compute_flow_and_cleanup(flownet, runs):
+    resampler = Resample2d().cuda()
+    for run in runs:
+        frames, path = run
+        consolidated_flows = None
+        a = frames[0].unsqueeze(0).cuda()
+        img = a
+        dbg = a.clone()
+        for i in range(1, len(frames)):
+            img2 = frames[i].unsqueeze(0).cuda()
+            flow = flownet(torch.cat([img2, img], dim=1))
+            flow = torch.nn.functional.interpolate(flow, size=img.shape[2:], mode='bilinear')
+            if consolidated_flows is None:
+                consolidated_flows = flow
+            else:
+                # Compose the accumulated flow with this step's flow so that consolidated_flows
+                # always relates the first frame of the run to the current frame.
+                consolidated_flows = resampler(flow, -consolidated_flows) + consolidated_flows
+            img = img2
+            dbg = resampler(dbg, flow)
+        torchvision.utils.save_image(dbg, os.path.join(path, "debug.jpg"))
+        consolidated_flows = torch.clamp(consolidated_flows / 255, -.5, .5)
+        b = frames[-1].unsqueeze(0).cuda()
+        _, _, h, w = a.shape
+        direct_flows = torch.nn.functional.interpolate(torch.clamp(flownet(torch.cat([a, b], dim=1).float()) / 255, -.5, .5), size=img.shape[2:], mode='bilinear')
+        # TODO: Reshape the image here. (The block below was carried over from a dataset
+        # class and still references self.opt, so it cannot be re-enabled as-is.)
+        '''
+        # Perform explicit crops first. These are generally used to get rid of watermarks, so we don't even want to
+        # consider those areas of the image.
+        if 'bottom_crop' in self.opt.keys() and self.opt['bottom_crop'] > 0:
+            bc = self.opt['bottom_crop']
+            if bc > 0 and bc < 1:
+                bc = int(bc * img.shape[0])
+            img = img[:-bc, :, :]
+
+        h, w, c = img.shape
+        assert min(h, w) >= self.opt['imgsize']
+
+        # We must convert the image into a square.
+        dim = min(h, w)
+        # Crop the image so that only the center is left, since this is often the most salient part of the image.
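+        # Worked example (illustrative, not from the original): with h=720, w=1280,
+        # dim=720, the slice below keeps rows 0:720 and columns 280:1000.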
+        img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
+        img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
+        '''
+
+        torchvision.utils.save_image(a, os.path.join(path, "a.jpg"))
+        torchvision.utils.save_image(b, os.path.join(path, "b.jpg"))
+        torch.save(consolidated_flows * 255, os.path.join(path, "consolidated_flow.pt"))
+        torchvision.utils.save_image(torch.cat([consolidated_flows + .5, torch.zeros((1, 1, h, w), device='cuda')], dim=1), os.path.join(path, "consolidated_flow.png"))
+
+        # For debugging: warp the endpoint frames with the computed flows so the
+        # results can be checked visually against a.jpg and b.jpg.
+        torchvision.utils.save_image(resampler(a, consolidated_flows * 255), os.path.join(path, "b_flowed.jpg"))
+        torchvision.utils.save_image(resampler(b, -consolidated_flows * 255), os.path.join(path, "a_flowed.jpg"))
+        torchvision.utils.save_image(resampler(b, direct_flows * 255), os.path.join(path, "a_flowed_nonconsolidated.jpg"))
+        torchvision.utils.save_image(torch.cat([direct_flows + .5, torch.zeros((1, 1, h, w), device='cuda')], dim=1), os.path.join(path, "direct_flow.png"))
+
+
+def identity(x):
+    # Pass-through collate function: each dataset item is already a complete list of runs.
+    return x
+
+
+def go(opt):
+    flownet = pwc_dc_net('../experiments/pwc_humanflow.pth')
+    flownet.eval()
+    flownet = flownet.cuda()
+
+    dataset = VideoClipDataset(opt)
+    dataloader = data.DataLoader(dataset, num_workers=opt['n_thread'], collate_fn=identity)
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            compute_flow_and_cleanup(flownet, batch[0])
+
+
+if __name__ == '__main__':
+    main()
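
A minimal sketch of how the saved outputs could be consumed downstream (illustrative only: the clip folder name is hypothetical, and this assumes the Resample2d CUDA extension used above is built):

    import os
    import torch
    import torchvision
    from PIL import Image
    from torchvision.transforms.functional import to_tensor
    from models.flownet2.networks.resample2d_package.resample2d import Resample2d

    clip_dir = 'D:\\dlas\\codes\\scripts\\test_out\\0_2'  # hypothetical <video-index>_<start-second> folder
    resampler = Resample2d().cuda()
    a = to_tensor(Image.open(os.path.join(clip_dir, 'a.jpg'))).unsqueeze(0).cuda()
    flow = torch.load(os.path.join(clip_dir, 'consolidated_flow.pt'))  # saved in pixel units
    # Warping a.jpg along the consolidated flow should approximate b.jpg, mirroring
    # the "b_flowed.jpg" debug output the script writes.
    torchvision.utils.save_image(resampler(a, flow.cuda()), os.path.join(clip_dir, 'b_check.jpg'))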