import argparse
import logging
import os
import os.path as osp
import subprocess
import time

import torch
import torch.utils.data as data
import torchvision.transforms.functional as F
from PIL import Image
from tqdm import tqdm

from trainer.ExtensibleTrainer import ExtensibleTrainer
from utils import options as option
import utils.util as util
from data import create_dataloader


class FfmpegBackedVideoDataset(data.Dataset):
    '''Pulls frames from a video one at a time using FFMPEG.

    Each item is one frame of ``opt['video_file']`` between ``start_at_seconds`` and
    ``end_at_seconds`` (or one vertical strip of a frame when ``vertical_splits > 1``).
    '''

    def __init__(self, opt, working_dir):
        super(FfmpegBackedVideoDataset, self).__init__()
        self.opt = opt
        self.video = self.opt['video_file']
        self.working_dir = working_dir
        self.frame_rate = self.opt['frame_rate']
        self.start_at = self.opt['start_at_seconds']
        self.end_at = self.opt['end_at_seconds']
        # A missing/None value means "no padding".
        self.force_multiple = opt.get('force_multiple') or 1
        # int(): __len__ must return an int even for fractional rates/timestamps (e.g. 29.97 fps).
        self.frame_count = int((self.end_at - self.start_at) * self.frame_rate)
        # The number of (original) video frames that will be stored on the filesystem at a time.
        self.max_working_files = 20
        self.data_type = self.opt['data_type']
        # A missing, None or 0 value means "do not split".
        self.vertical_splits = opt.get('vertical_splits') or 1
        if self.force_multiple > 1 and self.vertical_splits > 1:
            raise ValueError('force_multiple is not compatible with vertical_splits > 1.')

    def get_time_for_it(self, it):
        """Return the ffmpeg timestamp ('HH:MM:SS.mmm') of frame ``it``, offset by ``start_at``."""
        # divmod on the total keeps each field in range; the old code subtracted the hours twice.
        hours, remainder = divmod(it / self.frame_rate + self.start_at, 3600)
        mins, secs = divmod(remainder, 60)
        return '%02d:%02d:%06.3f' % (hours, mins, secs)

    def _pad_to_multiple(self, img):
        """Zero-pad a (C, H, W) tensor on the bottom/right so H and W are multiples of force_multiple."""
        c, h, w = img.shape
        m = self.force_multiple
        padded = torch.zeros(c, -(-h // m) * m, -(-w // m) * m, dtype=img.dtype)
        padded[:, :h, :w] = img
        return padded

    def __getitem__(self, index):
        # With vertical splits, consecutive indices are the strips of the same video frame.
        actual_index = index // self.vertical_splits
        split_index = index % self.vertical_splits

        # Extract the frame. Command template:
        #   ffmpeg -ss 17:00.0323 -i <video file>.mp4 -vframes 1 destination.png
        working_file_name = osp.join(self.working_dir, "working_%d.png" % (actual_index % self.max_working_files,))
        vid_time = self.get_time_for_it(actual_index)
        ffmpeg_args = ['ffmpeg', '-y', '-ss', vid_time, '-i', self.video, '-vframes', '1', working_file_name]
        # check=True: a failed extraction would otherwise silently reuse a stale working file.
        subprocess.run(ffmpeg_args, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, check=True)

        # get LQ image; convert() makes an in-memory copy so the file handle can be closed.
        with Image.open(working_file_name) as frame:
            img_LQ = frame.convert('RGB')
        if self.vertical_splits > 1:
            w, h = img_LQ.size
            w_per_split = w // self.vertical_splits
            img_LQ = F.crop(img_LQ, 0, w_per_split * split_index, h, w_per_split)
        img_LQ = F.to_tensor(img_LQ)
        mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
        ref = torch.cat([img_LQ, mask], dim=0)

        if self.force_multiple > 1:
            img_LQ = self._pad_to_multiple(img_LQ)
            # Pad the full reference (image + mask) so it keeps the same channel count as the
            # unpadded path; the padded area of the mask channel is zero.
            ref = self._pad_to_multiple(ref)

        return {'lq': img_LQ, 'lq_fullsize_ref': ref,
                'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long)}

    def __len__(self):
        return self.frame_count * self.vertical_splits


def merge_images(files, output_path):
    """Stitch images side by side (left to right) and save the result to ``output_path``.

    Every image is placed at a horizontal offset of ``i * width`` of the first image, so all
    inputs are expected to share the first image's size.
    """
    images = []
    for f in files:
        with Image.open(f) as im:
            images.append(im.convert('RGB'))
    w, h = images[0].size
    result = Image.new('RGB', (w * len(images), h))
    for i, im in enumerate(images):
        result.paste(im=im, box=(i * w, 0))
    result.save(output_path)


def _stitch_splits(frames_dir, num_frames, num_splits):
    """Return (directory, frame count) holding the full-width frames to encode.

    When ``num_splits > 1`` every run of ``num_splits`` consecutive strip images in ``frames_dir``
    is merged into one frame under ``frames_dir/joined``; otherwise ``frames_dir`` is used as is.
    """
    if num_splits <= 1:
        return frames_dir, num_frames
    joined_dir = osp.join(frames_dir, "joined")
    os.makedirs(joined_dir, exist_ok=True)
    num_joined = num_frames // num_splits
    for i in range(num_joined):
        to_join = [osp.join(frames_dir, "%08d.png" % (j,)) for j in range(i * num_splits, (i + 1) * num_splits)]
        merge_images(to_join, osp.join(joined_dir, "%08d.png" % (i,)))
    return joined_dir, num_joined


def _encode_minivid(prev_proc, frames_dir, num_frames, num_splits, frame_rate, crf, output_dir, vid_index):
    """Start an asynchronous x265 encode of the first ``num_frames`` images in ``frames_dir``.

    Waits for ``prev_proc`` first so at most one encoder runs at a time; together with the two
    alternating frame buffers in ``main`` this guarantees no frame is overwritten while being read.
    Returns the new ffmpeg ``Popen`` handle.
    """
    if prev_proc is not None:
        print("Waiting for last encode..")
        prev_proc.wait()
    print("Encoding minivid %d.." % (vid_index,))

    # Perform stitching.
    src_imgs_path, src_frames = _stitch_splits(frames_dir, num_frames, num_splits)

    # Encoding command line:
    # ffmpeg -framerate 30 -i %08d.png -c:v libx265 -crf 12 -preset slow -pix_fmt yuv444p test.mkv
    # -frames:v stops ffmpeg from also picking up stale, higher-numbered frames left in a reused buffer.
    cmd = ['ffmpeg', '-y', '-framerate', str(frame_rate), '-f', 'image2',
           '-i', osp.join(src_imgs_path, "%08d.png"),
           '-frames:v', str(src_frames),
           '-c:v', 'libx265', '-crf', str(crf), '-preset', 'slow', '-pix_fmt', 'yuv444p',
           osp.join(output_dir, "mini_%06d.mkv" % (vid_index,))]
    proc = subprocess.Popen(cmd)
    print("Done.")
    return proc


def main():
    """Upsample a video segment with a trained model and encode the results as x265 mini videos."""
    #### options
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.',
                        default='../options/use_video_upsample.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    util.loaded_options = opt

    #### Create test dataset and dataloader
    dataset_opt = opt['dataset']
    test_set = FfmpegBackedVideoDataset(dataset_opt, opt['path']['results_root'])
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))

    model = ExtensibleTrainer(opt)

    test_set_name = test_set.opt['name']
    logger.info('\nTesting [{:s}]...'.format(test_set_name))
    test_start_time = time.time()
    dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
    util.mkdir(dataset_dir)

    frames_per_vid = opt['frames_per_mini_vid']
    minivid_crf = opt['minivid_crf']
    vid_output = opt.get('mini_vid_output_folder') or dataset_dir
    os.makedirs(vid_output, exist_ok=True)
    vid_counter = opt.get('minivid_start_no') or 0
    frame_rate = dataset_opt['frame_rate']
    num_splits = dataset_opt.get('vertical_splits') or 1
    if frames_per_vid % num_splits != 0:
        # Otherwise the strips of one frame would straddle two mini videos.
        raise ValueError('frames_per_mini_vid must be a multiple of vertical_splits.')

    recurrent_mode = opt['recurrent_mode']
    if recurrent_mode and dataset_opt['batch_size'] != 1:
        # Can only do 1 frame at a time in recurrent mode, by definition.
        raise ValueError('recurrent_mode requires a dataset batch_size of 1.')
    scale = opt['scale']
    need_GT = test_set.opt.get('dataroot_GT') is not None

    # Frames are written into two alternating buffers: while ffmpeg encodes one buffer in the
    # background, the next mini video's frames go into the other one.
    buffer_dirs = [osp.join(dataset_dir, 'buffer_%d' % (k,)) for k in range(2)]
    for buffer_dir in buffer_dirs:
        os.makedirs(buffer_dir, exist_ok=True)
    active_buffer = 0

    frame_counter = 0
    first_frame = True
    recurrent_entry = None
    saved_generator = None
    ffmpeg_proc = None
    for batch in tqdm(test_loader):
        if recurrent_mode and first_frame:
            b, c, h, w = batch['lq'].shape
            recurrent_entry = torch.zeros((b, c, h * scale, w * scale), device=batch['lq'].device)
            # Optionally swap out the 'generator' for the first frame to create a better image that
            # the recurrent generator works off of.
            hr_generator_name = opt.get('recurrent_hr_generator')
            if hr_generator_name is not None:
                generators = model.env['generators']
                saved_generator = generators['generator']
                generators['generator'] = generators[hr_generator_name]
            first_frame = False
        if recurrent_mode:
            batch['recurrent'] = recurrent_entry
        model.feed_data(batch, 0, need_GT=need_GT)
        model.test()
        if saved_generator is not None:
            # The HR generator is only meant for the first frame; put the recurrent one back.
            model.env['generators']['generator'] = saved_generator
            saved_generator = None

        visuals = model.get_current_visuals()['rlt']
        if recurrent_mode:
            recurrent_entry = visuals
        visuals = visuals.cpu().float()

        for visual in visuals:
            sr_img = util.tensor2img(visual)  # uint8

            # save images
            save_img_path = osp.join(buffer_dirs[active_buffer], '%08d.png' % (frame_counter,))
            util.save_img(sr_img, save_img_path)
            frame_counter += 1

            if frame_counter == frames_per_vid:
                ffmpeg_proc = _encode_minivid(ffmpeg_proc, buffer_dirs[active_buffer], frame_counter,
                                              num_splits, frame_rate, minivid_crf, vid_output, vid_counter)
                vid_counter += 1
                frame_counter = 0
                active_buffer = 1 - active_buffer

    # Encode the trailing frames that did not fill a whole mini video, then wait for the encoder.
    if frame_counter > 0:
        ffmpeg_proc = _encode_minivid(ffmpeg_proc, buffer_dirs[active_buffer], frame_counter,
                                      num_splits, frame_rate, minivid_crf, vid_output, vid_counter)
    if ffmpeg_proc is not None:
        print("Waiting for last encode..")
        ffmpeg_proc.wait()
    logger.info('Finished in {:.1f}s'.format(time.time() - test_start_time))


if __name__ == "__main__":
    main()