Fix process_video

This commit is contained in:
James Betker 2021-01-09 20:53:46 -07:00
parent 07168ecfb4
commit 7c6c7a8014
2 changed files with 6 additions and 4 deletions

View File

@ -170,7 +170,7 @@ if __name__ == "__main__":
if recurrent_mode: if recurrent_mode:
data['recurrent'] = recurrent_entry data['recurrent'] = recurrent_entry
model.feed_data(data, need_GT=need_GT) model.feed_data(data, 0, need_GT=need_GT)
model.test() model.test()
visuals = model.get_current_visuals()['rlt'] visuals = model.get_current_visuals()['rlt']

View File

@ -336,9 +336,11 @@ class ExtensibleTrainer(BaseModel):
def get_current_visuals(self, need_GT=True): def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR. # Conforms to an archaic format from MMSR.
return {'lq': self.eval_state['lq'][0].float().cpu(), res = {'lq': self.eval_state['lq'][0].float().cpu(),
'hq': self.eval_state['hq'][0].float().cpu(), 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()} if 'hq' in self.eval_state.keys():
res['hq'] = self.eval_state['hq'][0].float().cpu(),
return res
def print_network(self): def print_network(self):
for name, net in self.networks.items(): for name, net in self.networks.items():