diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 001082a9..c0be7ba9 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -130,16 +130,20 @@ class ExtensibleTrainer(BaseModel): # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. self.updated = True - def feed_data(self, data, need_GT=False): + def feed_data(self, data, need_GT=True): self.eval_state = {} for o in self.optimizers: o.zero_grad() torch.cuda.empty_cache() self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0) - self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] - input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] - self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + if need_GT: + self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] + input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] + self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + else: + self.hq = self.lq + self.ref = self.lq self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} for k, v in data.items(): @@ -228,6 +232,9 @@ class ExtensibleTrainer(BaseModel): for k, v in state.items(): self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v] + # For backwards compatibility.. + self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu() + for net in self.netsG.values(): net.train() diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 6902df1e..a99292bd 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -552,7 +552,6 @@ class SwitchedSpsrWithRef2(nn.Module): self.final_temperature_step = 10000 def forward(self, x, ref, center_coord): - x_grad = self.get_g_nopadding(x) ref = self.reference_processor(ref, center_coord) x = self.model_fea_conv(x) @@ -561,6 +560,7 @@ class SwitchedSpsrWithRef2(nn.Module): x_fea = self.feature_lr_conv(x2) x_fea = self.feature_lr_conv2(x_fea) + x_grad = self.get_g_nopadding(x) x_grad = self.grad_conv(x_grad) x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref), identity=x_grad, output_attention_weights=True) diff --git a/codes/process_video.py b/codes/process_video.py index 21dbfeb9..aa69c242 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -67,7 +67,10 @@ class FfmpegBackedVideoDataset(data.Dataset): img_LQ = F.crop(img_LQ, 0, left, h, w_per_split) img_LQ = F.to_tensor(img_LQ) - return {'LQ': img_LQ} + mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2]) + ref = torch.cat([img_LQ, mask], dim=0) + return {'LQ': img_LQ, 'lq_fullsize_ref': ref, + 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) } def __len__(self): return self.frame_count * self.vertical_splits