Mods to support video processing with teco networks
This commit is contained in:
parent
17d78195ee
commit
e620fc05ba
|
@ -160,10 +160,10 @@ class FlowAdjustment(Injector):
|
|||
def forward(self, state):
|
||||
flow = self.env['generators'][self.flow]
|
||||
flow_target = state[self.flow_target]
|
||||
flowed = state[self.flowed]
|
||||
flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
|
||||
flow_input = torch.stack([flow_target, flowed], dim=2)
|
||||
flowfield = flow(flow_input)
|
||||
return {self.output: self.resample(flowed.float(), flowfield.float())}
|
||||
flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
|
||||
return {self.output: self.resample(state[self.flowed].float(), flowfield.float())}
|
||||
|
||||
|
||||
# This is the temporal discriminator loss from TecoGAN.
|
||||
|
|
|
@ -142,6 +142,9 @@ if __name__ == "__main__":
|
|||
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
|
||||
img_index = opt['generator_img_index']
|
||||
recurrent_mode = opt['recurrent_mode']
|
||||
if recurrent_mode:
|
||||
assert opt['dataset']['batch_size'] == 1 # Can only do 1 frame at a time in recurrent mode, by definition.
|
||||
scale = opt['scale']
|
||||
first_frame = True
|
||||
ffmpeg_proc = None
|
||||
|
||||
|
@ -150,7 +153,8 @@ if __name__ == "__main__":
|
|||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||
|
||||
if recurrent_mode and first_frame:
|
||||
recurrent_entry = data['LQ'].detach().clone()
|
||||
b, c, h, w = data['LQ'].shape
|
||||
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
|
||||
first_frame = False
|
||||
if recurrent_mode:
|
||||
data['recurrent'] = recurrent_entry
|
||||
|
@ -163,7 +167,7 @@ if __name__ == "__main__":
|
|||
else:
|
||||
visuals = model.fake_H.detach()
|
||||
if recurrent_mode:
|
||||
recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic')
|
||||
recurrent_entry = visuals
|
||||
visuals = visuals.cpu().float()
|
||||
for i in range(visuals.shape[0]):
|
||||
sr_img = util.tensor2img(visuals[i]) # uint8
|
||||
|
|
Loading…
Reference in New Issue
Block a user