diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index c95e214b..c1a57e3a 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -160,10 +160,10 @@ class FlowAdjustment(Injector): def forward(self, state): flow = self.env['generators'][self.flow] flow_target = state[self.flow_target] - flowed = state[self.flowed] + flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic') flow_input = torch.stack([flow_target, flowed], dim=2) - flowfield = flow(flow_input) - return {self.output: self.resample(flowed.float(), flowfield.float())} + flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic') + return {self.output: self.resample(state[self.flowed].float(), flowfield.float())} # This is the temporal discriminator loss from TecoGAN. diff --git a/codes/process_video.py b/codes/process_video.py index a8b411d4..96ec23bf 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -142,6 +142,9 @@ if __name__ == "__main__": vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0 img_index = opt['generator_img_index'] recurrent_mode = opt['recurrent_mode'] + if recurrent_mode: + assert opt['dataset']['batch_size'] == 1 # Can only do 1 frame at a time in recurrent mode, by definition. + scale = opt['scale'] first_frame = True ffmpeg_proc = None @@ -150,7 +153,8 @@ if __name__ == "__main__": need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True if recurrent_mode and first_frame: - recurrent_entry = data['LQ'].detach().clone() + b, c, h, w = data['LQ'].shape + recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device) first_frame = False if recurrent_mode: data['recurrent'] = recurrent_entry @@ -163,7 +167,7 @@ if __name__ == "__main__": else: visuals = model.fake_H.detach() if recurrent_mode: - recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic') + recurrent_entry = visuals visuals = visuals.cpu().float() for i in range(visuals.shape[0]): sr_img = util.tensor2img(visuals[i]) # uint8