forked from mrq/DL-Art-School
Make ExtensibleTrainer compatible with process_video
This commit is contained in:
parent
a18ece62ee
commit
f43df7f5f7
|
@ -130,16 +130,20 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
||||
self.updated = True
|
||||
|
||||
def feed_data(self, data, need_GT=False):
|
||||
def feed_data(self, data, need_GT=True):
|
||||
self.eval_state = {}
|
||||
for o in self.optimizers:
|
||||
o.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
if need_GT:
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
else:
|
||||
self.hq = self.lq
|
||||
self.ref = self.lq
|
||||
|
||||
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||
for k, v in data.items():
|
||||
|
@ -228,6 +232,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
for k, v in state.items():
|
||||
self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
|
||||
|
||||
# For backwards compatibility..
|
||||
self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()
|
||||
|
||||
for net in self.netsG.values():
|
||||
net.train()
|
||||
|
||||
|
|
|
@ -552,7 +552,6 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
self.final_temperature_step = 10000
|
||||
|
||||
def forward(self, x, ref, center_coord):
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref = self.reference_processor(ref, center_coord)
|
||||
x = self.model_fea_conv(x)
|
||||
|
||||
|
@ -561,6 +560,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
x_fea = self.feature_lr_conv(x2)
|
||||
x_fea = self.feature_lr_conv2(x_fea)
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
|
||||
identity=x_grad, output_attention_weights=True)
|
||||
|
|
|
@ -67,7 +67,10 @@ class FfmpegBackedVideoDataset(data.Dataset):
|
|||
img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
|
||||
return {'LQ': img_LQ}
|
||||
mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
|
||||
ref = torch.cat([img_LQ, mask], dim=0)
|
||||
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
|
||||
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
|
||||
|
||||
def __len__(self):
|
||||
return self.frame_count * self.vertical_splits
|
||||
|
|
Loading…
Reference in New Issue
Block a user