diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 4fa72fbc..3af72ecd 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -380,9 +380,10 @@ class SRGANModel(BaseModel): self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0) if need_GT: self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data else data['GT'] + input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] - self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)] + input_pix = data['PIX'] if 'pix' in data.keys() else data['GT'] + self.pix = [t.to(self.device) for t in torch.chunk(input_pix, chunks=self.mega_batch_factor, dim=0)] if 'GAN' in data.keys(): self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]