Fix process_video
This commit is contained in:
parent
07168ecfb4
commit
7c6c7a8014
|
@ -170,7 +170,7 @@ if __name__ == "__main__":
|
|||
if recurrent_mode:
|
||||
data['recurrent'] = recurrent_entry
|
||||
|
||||
model.feed_data(data, need_GT=need_GT)
|
||||
model.feed_data(data, 0, need_GT=need_GT)
|
||||
model.test()
|
||||
visuals = model.get_current_visuals()['rlt']
|
||||
|
||||
|
|
|
@ -336,9 +336,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
# Conforms to an archaic format from MMSR.
|
||||
return {'lq': self.eval_state['lq'][0].float().cpu(),
|
||||
'hq': self.eval_state['hq'][0].float().cpu(),
|
||||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||
res = {'lq': self.eval_state['lq'][0].float().cpu(),
|
||||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||
if 'hq' in self.eval_state.keys():
|
||||
res['hq'] = self.eval_state['hq'][0].float().cpu(),
|
||||
return res
|
||||
|
||||
def print_network(self):
|
||||
for name, net in self.networks.items():
|
||||
|
|
Loading…
Reference in New Issue
Block a user