Allow tecogan to be used in process_video
This commit is contained in:
parent
58d8bf8f69
commit
7e777ea34c
|
@ -61,19 +61,20 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
tq = tqdm(test_loader)
|
tq = tqdm(test_loader)
|
||||||
removed = 0
|
removed = 0
|
||||||
|
means = []
|
||||||
|
dataset_mean = -7.133
|
||||||
for data in tq:
|
for data in tq:
|
||||||
model.feed_data(data, need_GT=True)
|
model.feed_data(data, need_GT=True)
|
||||||
model.test()
|
model.test()
|
||||||
results = model.eval_state['discriminator_out'][0]
|
results = model.eval_state['discriminator_out'][0]
|
||||||
print(torch.mean(results), torch.max(results), torch.min(results))
|
means.append(torch.mean(results).item())
|
||||||
|
print(sum(means)/len(means), torch.mean(results), torch.max(results), torch.min(results))
|
||||||
for i in range(results.shape[0]):
|
for i in range(results.shape[0]):
|
||||||
if results[i] < .8:
|
#if results[i] < .8:
|
||||||
os.remove(data['GT_path'][i])
|
# os.remove(data['GT_path'][i])
|
||||||
removed += 1
|
# removed += 1
|
||||||
#imname = osp.basename(data['GT_path'][i])
|
imname = osp.basename(data['GT_path'][i])
|
||||||
#if results[i] > .8:
|
if results[i]-dataset_mean > 1:
|
||||||
# torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
|
torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
|
||||||
#else:
|
|
||||||
# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
|
|
||||||
|
|
||||||
print("Removed %i/%i images" % (removed, len(test_set)))
|
print("Removed %i/%i images" % (removed, len(test_set)))
|
|
@ -30,8 +30,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.env = {'device': self.device,
|
self.env = {'device': self.device,
|
||||||
'rank': self.rank,
|
'rank': self.rank,
|
||||||
'opt': opt,
|
'opt': opt,
|
||||||
'step': 0,
|
'step': 0}
|
||||||
'base_path': os.path.join(opt['path']['models'])}
|
if opt['path']['models'] is not None:
|
||||||
|
self.env['base_path'] = os.path.join(opt['path']['models'])
|
||||||
|
|
||||||
self.mega_batch_factor = 1
|
self.mega_batch_factor = 1
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
|
|
@ -21,6 +21,8 @@ def create_teco_injector(opt, env):
|
||||||
type = opt['type']
|
type = opt['type']
|
||||||
if type == 'teco_recurrent_generated_sequence_injector':
|
if type == 'teco_recurrent_generated_sequence_injector':
|
||||||
return RecurrentImageGeneratorSequenceInjector(opt, env)
|
return RecurrentImageGeneratorSequenceInjector(opt, env)
|
||||||
|
elif type == 'teco_flow_adjustment':
|
||||||
|
return FlowAdjustment(opt, env)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
|
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
|
||||||
|
@ -132,6 +134,23 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,)))
|
torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||||
|
|
||||||
|
|
||||||
|
class FlowAdjustment(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(FlowAdjustment, self).__init__(opt, env)
|
||||||
|
self.resample = Resample2d()
|
||||||
|
self.flow = opt['flow_network']
|
||||||
|
self.flow_target = opt['flow_target']
|
||||||
|
self.flowed = opt['flowed']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
flow = self.env['generators'][self.flow]
|
||||||
|
flow_target = state[self.flow_target]
|
||||||
|
flowed = state[self.flowed]
|
||||||
|
flow_input = torch.stack([flow_target, flowed], dim=2)
|
||||||
|
flowfield = flow(flow_input)
|
||||||
|
return {self.output: self.resample(flowed.float(), flowfield.float())}
|
||||||
|
|
||||||
|
|
||||||
# This is the temporal discriminator loss from TecoGAN.
|
# This is the temporal discriminator loss from TecoGAN.
|
||||||
#
|
#
|
||||||
# It has a strict contract for 'real' and 'fake' inputs:
|
# It has a strict contract for 'real' and 'fake' inputs:
|
||||||
|
|
|
@ -28,6 +28,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
|
||||||
self.frame_rate = self.opt['frame_rate']
|
self.frame_rate = self.opt['frame_rate']
|
||||||
self.start_at = self.opt['start_at_seconds']
|
self.start_at = self.opt['start_at_seconds']
|
||||||
self.end_at = self.opt['end_at_seconds']
|
self.end_at = self.opt['end_at_seconds']
|
||||||
|
self.force_multiple = self.opt['force_multiple']
|
||||||
self.frame_count = (self.end_at - self.start_at) * self.frame_rate
|
self.frame_count = (self.end_at - self.start_at) * self.frame_rate
|
||||||
# The number of (original) video frames that will be stored on the filesystem at a time.
|
# The number of (original) video frames that will be stored on the filesystem at a time.
|
||||||
self.max_working_files = 20
|
self.max_working_files = 20
|
||||||
|
@ -69,6 +70,18 @@ class FfmpegBackedVideoDataset(data.Dataset):
|
||||||
|
|
||||||
mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
|
mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
|
||||||
ref = torch.cat([img_LQ, mask], dim=0)
|
ref = torch.cat([img_LQ, mask], dim=0)
|
||||||
|
|
||||||
|
if self.force_multiple > 1:
|
||||||
|
assert self.vertical_splits <= 1 # This is not compatible with vertical splits for now.
|
||||||
|
_, h, w = img_LQ.shape
|
||||||
|
height_removed = h % self.force_multiple
|
||||||
|
width_removed = w % self.force_multiple
|
||||||
|
if height_removed != 0:
|
||||||
|
img_LQ = img_LQ[:, :-height_removed, :]
|
||||||
|
ref = ref[:, :-height_removed, :]
|
||||||
|
if width_removed != 0:
|
||||||
|
img_LQ = img_LQ[:, :, :-width_removed]
|
||||||
|
ref = ref[:, :, :-width_removed]
|
||||||
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
|
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
|
||||||
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
|
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
|
||||||
|
|
||||||
|
@ -128,18 +141,30 @@ if __name__ == "__main__":
|
||||||
vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir
|
vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir
|
||||||
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
|
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
|
||||||
img_index = opt['generator_img_index']
|
img_index = opt['generator_img_index']
|
||||||
|
recurrent_mode = opt['recurrent_mode']
|
||||||
|
first_frame = True
|
||||||
ffmpeg_proc = None
|
ffmpeg_proc = None
|
||||||
|
|
||||||
tq = tqdm(test_loader)
|
tq = tqdm(test_loader)
|
||||||
for data in tq:
|
for data in tq:
|
||||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||||
|
|
||||||
|
if recurrent_mode and first_frame:
|
||||||
|
recurrent_entry = data['LQ'].detach().clone()
|
||||||
|
first_frame = False
|
||||||
|
if recurrent_mode:
|
||||||
|
data['recurrent'] = recurrent_entry
|
||||||
|
|
||||||
model.feed_data(data, need_GT=need_GT)
|
model.feed_data(data, need_GT=need_GT)
|
||||||
model.test()
|
model.test()
|
||||||
|
|
||||||
if isinstance(model.fake_H, tuple):
|
if isinstance(model.fake_H, tuple):
|
||||||
visuals = model.fake_H[img_index].detach().float().cpu()
|
visuals = model.fake_H[img_index].detach()
|
||||||
else:
|
else:
|
||||||
visuals = model.fake_H.detach().float().cpu()
|
visuals = model.fake_H.detach()
|
||||||
|
if recurrent_mode:
|
||||||
|
recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic')
|
||||||
|
visuals = visuals.cpu().float()
|
||||||
for i in range(visuals.shape[0]):
|
for i in range(visuals.shape[0]):
|
||||||
sr_img = util.tensor2img(visuals[i]) # uint8
|
sr_img = util.tensor2img(visuals[i]) # uint8
|
||||||
|
|
||||||
|
@ -148,7 +173,6 @@ if __name__ == "__main__":
|
||||||
util.save_img(sr_img, save_img_path)
|
util.save_img(sr_img, save_img_path)
|
||||||
frame_counter += 1
|
frame_counter += 1
|
||||||
|
|
||||||
|
|
||||||
if frame_counter % frames_per_vid == 0:
|
if frame_counter % frames_per_vid == 0:
|
||||||
if ffmpeg_proc is not None:
|
if ffmpeg_proc is not None:
|
||||||
print("Waiting for last encode..")
|
print("Waiting for last encode..")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user