Mods to support video processing with teco networks

This commit is contained in:
James Betker 2020-10-13 20:47:05 -06:00
parent 17d78195ee
commit e620fc05ba
2 changed files with 9 additions and 5 deletions

View File

@@ -160,10 +160,10 @@ class FlowAdjustment(Injector):
def forward(self, state): def forward(self, state):
flow = self.env['generators'][self.flow] flow = self.env['generators'][self.flow]
flow_target = state[self.flow_target] flow_target = state[self.flow_target]
flowed = state[self.flowed] flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
flow_input = torch.stack([flow_target, flowed], dim=2) flow_input = torch.stack([flow_target, flowed], dim=2)
flowfield = flow(flow_input) flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
return {self.output: self.resample(flowed.float(), flowfield.float())} return {self.output: self.resample(state[self.flowed].float(), flowfield.float())}
# This is the temporal discriminator loss from TecoGAN. # This is the temporal discriminator loss from TecoGAN.

View File

@@ -142,6 +142,9 @@ if __name__ == "__main__":
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0 vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
img_index = opt['generator_img_index'] img_index = opt['generator_img_index']
recurrent_mode = opt['recurrent_mode'] recurrent_mode = opt['recurrent_mode']
if recurrent_mode:
assert opt['dataset']['batch_size'] == 1 # Can only do 1 frame at a time in recurrent mode, by definition.
scale = opt['scale']
first_frame = True first_frame = True
ffmpeg_proc = None ffmpeg_proc = None
@@ -150,7 +153,8 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if recurrent_mode and first_frame: if recurrent_mode and first_frame:
recurrent_entry = data['LQ'].detach().clone() b, c, h, w = data['LQ'].shape
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
first_frame = False first_frame = False
if recurrent_mode: if recurrent_mode:
data['recurrent'] = recurrent_entry data['recurrent'] = recurrent_entry
@@ -163,7 +167,7 @@ if __name__ == "__main__":
else: else:
visuals = model.fake_H.detach() visuals = model.fake_H.detach()
if recurrent_mode: if recurrent_mode:
recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic') recurrent_entry = visuals
visuals = visuals.cpu().float() visuals = visuals.cpu().float()
for i in range(visuals.shape[0]): for i in range(visuals.shape[0]):
sr_img = util.tensor2img(visuals[i]) # uint8 sr_img = util.tensor2img(visuals[i]) # uint8