Make ExtensibleTrainer compatible with process_video

This commit is contained in:
James Betker 2020-09-08 08:03:41 -06:00
parent a18ece62ee
commit f43df7f5f7
3 changed files with 16 additions and 6 deletions

View File

@ -130,16 +130,20 @@ class ExtensibleTrainer(BaseModel):
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
self.updated = True
def feed_data(self, data, need_GT=False):
def feed_data(self, data, need_GT=True):
self.eval_state = {}
for o in self.optimizers:
o.zero_grad()
torch.cuda.empty_cache()
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
if need_GT:
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
else:
self.hq = self.lq
self.ref = self.lq
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
for k, v in data.items():
@ -228,6 +232,9 @@ class ExtensibleTrainer(BaseModel):
for k, v in state.items():
self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
# For backwards compatibility..
self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()
for net in self.netsG.values():
net.train()

View File

@ -552,7 +552,6 @@ class SwitchedSpsrWithRef2(nn.Module):
self.final_temperature_step = 10000
def forward(self, x, ref, center_coord):
x_grad = self.get_g_nopadding(x)
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
@ -561,6 +560,7 @@ class SwitchedSpsrWithRef2(nn.Module):
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_lr_conv2(x_fea)
x_grad = self.get_g_nopadding(x)
x_grad = self.grad_conv(x_grad)
x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
identity=x_grad, output_attention_weights=True)

View File

@ -67,7 +67,10 @@ class FfmpegBackedVideoDataset(data.Dataset):
img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
img_LQ = F.to_tensor(img_LQ)
return {'LQ': img_LQ}
mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
ref = torch.cat([img_LQ, mask], dim=0)
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
def __len__(self):
return self.frame_count * self.vertical_splits