From 7c6c7a80145be5808187b0e4bc3dc8ea2b2cd29c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 9 Jan 2021 20:53:46 -0700 Subject: [PATCH] Fix process_video --- codes/process_video.py | 2 +- codes/trainer/ExtensibleTrainer.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/codes/process_video.py b/codes/process_video.py index 04cf5601..d023495f 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -170,7 +170,7 @@ if __name__ == "__main__": if recurrent_mode: data['recurrent'] = recurrent_entry - model.feed_data(data, need_GT=need_GT) + model.feed_data(data, 0, need_GT=need_GT) model.test() visuals = model.get_current_visuals()['rlt'] diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 33db1894..fb8ac6a1 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -336,9 +336,11 @@ class ExtensibleTrainer(BaseModel): def get_current_visuals(self, need_GT=True): # Conforms to an archaic format from MMSR. - return {'lq': self.eval_state['lq'][0].float().cpu(), - 'hq': self.eval_state['hq'][0].float().cpu(), - 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()} + res = {'lq': self.eval_state['lq'][0].float().cpu(), + 'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()} + if 'hq' in self.eval_state.keys(): + res['hq'] = self.eval_state['hq'][0].float().cpu(), + return res def print_network(self): for name, net in self.networks.items():