From 5606e8b0ee008bb101b437feeb2f6ac53003220c Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 8 Sep 2020 11:34:35 -0600 Subject: [PATCH] Fix SRGAN_model/fullimgdataset compatibility 1 --- codes/models/SRGAN_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 4fa72fbc..3af72ecd 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -380,9 +380,10 @@ class SRGANModel(BaseModel): self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0) if need_GT: self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data else data['GT'] + input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] - self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)] + input_pix = data['PIX'] if 'pix' in data.keys() else data['GT'] + self.pix = [t.to(self.device) for t in torch.chunk(input_pix, chunks=self.mega_batch_factor, dim=0)] if 'GAN' in data.keys(): self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]