Fix SRGAN_model/fullimgdataset compatibility 1

This commit is contained in:
James Betker 2020-09-08 11:34:35 -06:00
parent 22c98f1567
commit 5606e8b0ee

View File

@ -380,9 +380,10 @@ class SRGANModel(BaseModel):
self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
if need_GT:
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
input_pix = data['PIX'] if 'pix' in data.keys() else data['GT']
self.pix = [t.to(self.device) for t in torch.chunk(input_pix, chunks=self.mega_batch_factor, dim=0)]
if 'GAN' in data.keys():
self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]