diff --git a/codes/test.py b/codes/test.py index 4f4498a3..a764d3f8 100644 --- a/codes/test.py +++ b/codes/test.py @@ -55,7 +55,10 @@ if __name__ == "__main__": model.feed_data(data, need_GT=need_GT) model.test() - visuals = model.fake_H.detach().float().cpu() + if isinstance(model.fake_H, tuple): + visuals = model.fake_H[0].detach().float().cpu() + else: + visuals = model.fake_H.detach().float().cpu() for i in range(visuals.shape[0]): img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i] img_name = osp.splitext(osp.basename(img_path))[0]