Fix SRGAN_model/fullimgdataset compatibility 1
This commit is contained in:
parent
22c98f1567
commit
5606e8b0ee
|
@ -380,9 +380,10 @@ class SRGANModel(BaseModel):
|
|||
self.var_L = torch.chunk(corrupted_L, chunks=self.mega_batch_factor, dim=0)
|
||||
if need_GT:
|
||||
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_pix = data['PIX'] if 'pix' in data.keys() else data['GT']
|
||||
self.pix = [t.to(self.device) for t in torch.chunk(input_pix, chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
if 'GAN' in data.keys():
|
||||
self.gan_img = [t.to(self.device) for t in torch.chunk(data['GAN'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
|
Loading…
Reference in New Issue
Block a user