From 7e777ea34cfe86cb8c6daf8f294f61601ebe9294 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 9 Oct 2020 19:21:43 -0600
Subject: [PATCH] Allow tecogan to be used in process_video

---
 .../use_discriminator_as_filter.py   | 19 ++++++------
 codes/models/ExtensibleTrainer.py    |  5 ++--
 codes/models/steps/tecogan_losses.py | 19 ++++++++++++
 codes/process_video.py               | 30 +++++++++++++++++--
 4 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/codes/data_scripts/use_discriminator_as_filter.py b/codes/data_scripts/use_discriminator_as_filter.py
index 078641a8..66443772 100644
--- a/codes/data_scripts/use_discriminator_as_filter.py
+++ b/codes/data_scripts/use_discriminator_as_filter.py
@@ -61,19 +61,20 @@ if __name__ == "__main__":
 
     tq = tqdm(test_loader)
     removed = 0
+    means = []
+    dataset_mean = -7.133
     for data in tq:
         model.feed_data(data, need_GT=True)
         model.test()
         results = model.eval_state['discriminator_out'][0]
-        print(torch.mean(results), torch.max(results), torch.min(results))
+        means.append(torch.mean(results).item())
+        print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
         for i in range(results.shape[0]):
-            if results[i] < .8:
-                os.remove(data['GT_path'][i])
-                removed += 1
-            #imname = osp.basename(data['GT_path'][i])
-            #if results[i] > .8:
-            #    torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
-            #else:
-            #    torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+            #if results[i] < .8:
+            #    os.remove(data['GT_path'][i])
+            #    removed += 1
+            imname = osp.basename(data['GT_path'][i])
+            if results[i]-dataset_mean > 1:
+                torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
 
     print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 88e34671..51154fd0 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -30,8 +30,9 @@ class ExtensibleTrainer(BaseModel):
         self.env = {'device': self.device,
                     'rank': self.rank,
                     'opt': opt,
-                    'step': 0,
-                    'base_path': os.path.join(opt['path']['models'])}
+                    'step': 0}
+        if opt['path']['models'] is not None:
+            self.env['base_path'] = os.path.join(opt['path']['models'])
 
         self.mega_batch_factor = 1
         if self.is_train:
diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py
index 1c033c60..5565b749 100644
--- a/codes/models/steps/tecogan_losses.py
+++ b/codes/models/steps/tecogan_losses.py
@@ -21,6 +21,8 @@ def create_teco_injector(opt, env):
     type = opt['type']
     if type == 'teco_recurrent_generated_sequence_injector':
         return RecurrentImageGeneratorSequenceInjector(opt, env)
+    elif type == 'teco_flow_adjustment':
+        return FlowAdjustment(opt, env)
     return None
 
 def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
@@ -132,6 +134,23 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,)))
 
 
+class FlowAdjustment(Injector):
+    def __init__(self, opt, env):
+        super(FlowAdjustment, self).__init__(opt, env)
+        self.resample = Resample2d()
+        self.flow = opt['flow_network']
+        self.flow_target = opt['flow_target']
+        self.flowed = opt['flowed']
+
+    def forward(self, state):
+        flow = self.env['generators'][self.flow]
+        flow_target = state[self.flow_target]
+        flowed = state[self.flowed]
+        flow_input = torch.stack([flow_target, flowed], dim=2)
+        flowfield = flow(flow_input)
+        return {self.output: self.resample(flowed.float(), flowfield.float())}
+
+
 # This is the temporal discriminator loss from TecoGAN.
 #
 # It has a strict contract for 'real' and 'fake' inputs:
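Note: the new FlowAdjustment injector looks up a flow network from the trainer's
generator registry, stacks flow_target and flowed along a new dim=2, predicts a
flow field, and uses Resample2d to warp `flowed` toward `flow_target`. The dim=2
stacking produces the (B, C, 2, H, W) image-pair layout that FlowNet2-style
estimators expect. A minimal sketch of just that tensor contract, with dummy
tensors standing in for real frames (shapes here are illustrative, not from the
patch):

    import torch

    b, c, h, w = 2, 3, 64, 64              # illustrative batch/frame sizes
    flow_target = torch.randn(b, c, h, w)  # frame to warp toward
    flowed = torch.randn(b, c, h, w)       # frame that gets resampled

    # Same stacking as FlowAdjustment.forward(): the flow network sees the
    # image pair along a new depth dimension.
    flow_input = torch.stack([flow_target, flowed], dim=2)
    assert flow_input.shape == (b, c, 2, h, w)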
diff --git a/codes/process_video.py b/codes/process_video.py
index bc91d317..a8b411d4 100644
--- a/codes/process_video.py
+++ b/codes/process_video.py
@@ -28,6 +28,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
         self.frame_rate = self.opt['frame_rate']
         self.start_at = self.opt['start_at_seconds']
         self.end_at = self.opt['end_at_seconds']
+        self.force_multiple = self.opt['force_multiple']
         self.frame_count = (self.end_at - self.start_at) * self.frame_rate
         # The number of (original) video frames that will be stored on the filesystem at a time.
         self.max_working_files = 20
@@ -69,6 +70,18 @@ class FfmpegBackedVideoDataset(data.Dataset):
             mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
             ref = torch.cat([img_LQ, mask], dim=0)
 
+        if self.force_multiple > 1:
+            assert self.vertical_splits <= 1   # This is not compatible with vertical splits for now.
+            _, h, w = img_LQ.shape
+            height_removed = h % self.force_multiple
+            width_removed = w % self.force_multiple
+            if height_removed != 0:
+                img_LQ = img_LQ[:, :-height_removed, :]
+                ref = ref[:, :-height_removed, :]
+            if width_removed != 0:
+                img_LQ = img_LQ[:, :, :-width_removed]
+                ref = ref[:, :, :-width_removed]
+
         return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
 
@@ -128,18 +141,30 @@ if __name__ == "__main__":
     vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir
     vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
     img_index = opt['generator_img_index']
+    recurrent_mode = opt['recurrent_mode']
+    first_frame = True
 
     ffmpeg_proc = None
     tq = tqdm(test_loader)
     for data in tq:
         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
+
+        if recurrent_mode and first_frame:
+            recurrent_entry = data['LQ'].detach().clone()
+            first_frame = False
+        if recurrent_mode:
+            data['recurrent'] = recurrent_entry
+
         model.feed_data(data, need_GT=need_GT)
         model.test()
 
         if isinstance(model.fake_H, tuple):
-            visuals = model.fake_H[img_index].detach().float().cpu()
+            visuals = model.fake_H[img_index].detach()
         else:
-            visuals = model.fake_H.detach().float().cpu()
+            visuals = model.fake_H.detach()
+        if recurrent_mode:
+            recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic')
+        visuals = visuals.cpu().float()
 
         for i in range(visuals.shape[0]):
             sr_img = util.tensor2img(visuals[i])  # uint8
@@ -148,7 +173,6 @@ if __name__ == "__main__":
             util.save_img(sr_img, save_img_path)
 
             frame_counter += 1
-
             if frame_counter % frames_per_vid == 0:
                 if ffmpeg_proc is not None:
                     print("Waiting for last encode..")
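Note: the force_multiple cropping added to FfmpegBackedVideoDataset trims each
frame's height and width down to the nearest multiple, presumably so generators
with strided downsampling receive evenly divisible inputs. A standalone sketch
of the same bottom/right-edge crop (crop_to_multiple is a hypothetical helper,
not part of the patch):

    import torch

    def crop_to_multiple(img, multiple):
        # Trim the bottom/right edges so H and W divide evenly, mirroring
        # the logic added to FfmpegBackedVideoDataset.__getitem__.
        _, h, w = img.shape
        h_rem, w_rem = h % multiple, w % multiple
        if h_rem != 0:
            img = img[:, :-h_rem, :]
        if w_rem != 0:
            img = img[:, :, :-w_rem]
        return img

    frame = torch.rand(3, 721, 1281)
    assert crop_to_multiple(frame, 8).shape == (3, 720, 1280)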
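Note: the recurrent_mode changes close the TecoGAN feedback loop at inference
time: the first frame seeds data['recurrent'] with the LQ frame itself, and
every later frame receives the previous output bicubically downscaled back to
LQ resolution. A minimal sketch of that loop, with a trivial bicubic upscaler
standing in for the real generator driven by ExtensibleTrainer:

    import torch
    import torch.nn.functional as F

    scale = 2

    def generate(lq, recurrent):
        # Stand-in for the real model: just bicubic-upscales the input.
        return F.interpolate(lq, scale_factor=scale, mode='bicubic')

    recurrent_entry = None
    for _ in range(3):                       # three dummy "frames"
        lq = torch.rand(1, 3, 32, 32)
        if recurrent_entry is None:          # first frame: seed with the LQ frame
            recurrent_entry = lq.detach().clone()
        visuals = generate(lq, recurrent_entry)
        # Feed the downscaled output back in as the next frame's recurrent input.
        recurrent_entry = F.interpolate(visuals, scale_factor=1/scale, mode='bicubic')
        visuals = visuals.cpu().float()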