forked from mrq/DL-Art-School
Mods to support video processing with teco networks
This commit is contained in:
parent
17d78195ee
commit
e620fc05ba
|
@ -160,10 +160,10 @@ class FlowAdjustment(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
flow = self.env['generators'][self.flow]
|
flow = self.env['generators'][self.flow]
|
||||||
flow_target = state[self.flow_target]
|
flow_target = state[self.flow_target]
|
||||||
flowed = state[self.flowed]
|
flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
|
||||||
flow_input = torch.stack([flow_target, flowed], dim=2)
|
flow_input = torch.stack([flow_target, flowed], dim=2)
|
||||||
flowfield = flow(flow_input)
|
flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
|
||||||
return {self.output: self.resample(flowed.float(), flowfield.float())}
|
return {self.output: self.resample(state[self.flowed].float(), flowfield.float())}
|
||||||
|
|
||||||
|
|
||||||
# This is the temporal discriminator loss from TecoGAN.
|
# This is the temporal discriminator loss from TecoGAN.
|
||||||
|
|
|
@ -142,6 +142,9 @@ if __name__ == "__main__":
|
||||||
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
|
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
|
||||||
img_index = opt['generator_img_index']
|
img_index = opt['generator_img_index']
|
||||||
recurrent_mode = opt['recurrent_mode']
|
recurrent_mode = opt['recurrent_mode']
|
||||||
|
if recurrent_mode:
|
||||||
|
assert opt['dataset']['batch_size'] == 1 # Can only do 1 frame at a time in recurrent mode, by definition.
|
||||||
|
scale = opt['scale']
|
||||||
first_frame = True
|
first_frame = True
|
||||||
ffmpeg_proc = None
|
ffmpeg_proc = None
|
||||||
|
|
||||||
|
@ -150,7 +153,8 @@ if __name__ == "__main__":
|
||||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||||
|
|
||||||
if recurrent_mode and first_frame:
|
if recurrent_mode and first_frame:
|
||||||
recurrent_entry = data['LQ'].detach().clone()
|
b, c, h, w = data['LQ'].shape
|
||||||
|
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
|
||||||
first_frame = False
|
first_frame = False
|
||||||
if recurrent_mode:
|
if recurrent_mode:
|
||||||
data['recurrent'] = recurrent_entry
|
data['recurrent'] = recurrent_entry
|
||||||
|
@ -163,7 +167,7 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
visuals = model.fake_H.detach()
|
visuals = model.fake_H.detach()
|
||||||
if recurrent_mode:
|
if recurrent_mode:
|
||||||
recurrent_entry = torch.nn.functional.interpolate(visuals, scale_factor=1/opt['scale'], mode='bicubic')
|
recurrent_entry = visuals
|
||||||
visuals = visuals.cpu().float()
|
visuals = visuals.cpu().float()
|
||||||
for i in range(visuals.shape[0]):
|
for i in range(visuals.shape[0]):
|
||||||
sr_img = util.tensor2img(visuals[i]) # uint8
|
sr_img = util.tensor2img(visuals[i]) # uint8
|
||||||
|
|
Loading…
Reference in New Issue
Block a user