Make ExtensibleTrainer compatible with process_video

parent a18ece62ee
commit f43df7f5f7
@@ -130,16 +130,20 @@ class ExtensibleTrainer(BaseModel):
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
-    def feed_data(self, data, need_GT=False):
+    def feed_data(self, data, need_GT=True):
         self.eval_state = {}
         for o in self.optimizers:
             o.zero_grad()
         torch.cuda.empty_cache()
 
         self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
-        self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
-        input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-        self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
+        if need_GT:
+            self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
+            input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
+            self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
+        else:
+            self.hq = self.lq
+            self.ref = self.lq
 
         self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
         for k, v in data.items():
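Note: with need_GT now defaulting to True, training behavior is unchanged, while inference-only callers such as process_video can pass need_GT=False and feed LQ-only batches; hq and ref then simply alias the LQ chunks. A standalone sketch of the new branching, assuming a CPU device and an illustrative mega_batch_factor (this is not the real class method):

import torch

def feed_data_sketch(data, device='cpu', mega_batch_factor=2, need_GT=True):
    # Mirrors the branching this commit adds to ExtensibleTrainer.feed_data.
    lq = torch.chunk(data['LQ'].to(device), chunks=mega_batch_factor, dim=0)
    if need_GT:
        hq = [t.to(device) for t in torch.chunk(data['GT'], chunks=mega_batch_factor, dim=0)]
        input_ref = data['ref'] if 'ref' in data else data['GT']
        ref = [t.to(device) for t in torch.chunk(input_ref, chunks=mega_batch_factor, dim=0)]
    else:
        # Inference path used by process_video: no ground truth is available,
        # so hq/ref simply alias the LQ chunks.
        hq = lq
        ref = lq
    return {'lq': lq, 'hq': hq, 'ref': ref}

# An LQ-only call, as a video pipeline would issue it:
state = feed_data_sketch({'LQ': torch.randn(4, 3, 32, 32)}, need_GT=False)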
@@ -228,6 +232,9 @@ class ExtensibleTrainer(BaseModel):
             for k, v in state.items():
                 self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
 
+            # For backwards compatibility..
+            self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()
+
         for net in self.netsG.values():
             net.train()
 
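Note: the added lines keep older evaluation code working that reads the output from a fake_H attribute rather than from eval_state. A standalone illustration of what the new line computes; the option layout mirrors the diff, and the tensor shape is made up:

import torch

# eval_state maps state names to lists of per-chunk tensors, as populated above.
eval_state = {'gen_out': [torch.randn(2, 3, 64, 64).half()]}
opt = {'eval': {'output_state': 'gen_out'}}  # assumed option layout, as in the diff

# Pick the configured output state, take its first chunk, cast to float32 on CPU.
fake_H = eval_state[opt['eval']['output_state']][0].float().cpu()
print(fake_H.shape, fake_H.dtype)  # torch.Size([2, 3, 64, 64]) torch.float32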
@@ -552,7 +552,6 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, ref, center_coord):
-        x_grad = self.get_g_nopadding(x)
         ref = self.reference_processor(ref, center_coord)
         x = self.model_fea_conv(x)
 
@@ -561,6 +560,7 @@ class SwitchedSpsrWithRef2(nn.Module):
         x_fea = self.feature_lr_conv(x2)
         x_fea = self.feature_lr_conv2(x_fea)
 
+        x_grad = self.get_g_nopadding(x)
         x_grad = self.grad_conv(x_grad)
         x_grad, a3 = self.sw_grad((torch.cat([x_grad, x1], dim=1), ref),
                                   identity=x_grad, output_attention_weights=True)
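Note: these two hunks move the x_grad = self.get_g_nopadding(x) call from the top of forward() to immediately before the gradient branch that consumes it. Since x is reassigned by model_fea_conv in the earlier hunk, the relocated call now sees that later value of x unless the code between the hunks restores it, so this reads as a semantic change rather than a tidy-up. get_g_nopadding itself is not shown in this diff; in SPSR-derived code it is typically a fixed-kernel, shape-preserving image-gradient operator along the lines of the following sketch (an assumption, not this repository's implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GetGradientNopadding(nn.Module):
    # Hedged sketch of an SPSR-style gradient-map operator; shapes are
    # preserved via padding=1 and gradients are taken per channel.
    def __init__(self):
        super().__init__()
        kernel_v = torch.tensor([[0., -1., 0.],
                                 [0.,  0., 0.],
                                 [0.,  1., 0.]]).view(1, 1, 3, 3)
        kernel_h = torch.tensor([[0., 0., 0.],
                                 [-1., 0., 1.],
                                 [0., 0., 0.]]).view(1, 1, 3, 3)
        self.register_buffer('kernel_v', kernel_v)
        self.register_buffer('kernel_h', kernel_h)

    def forward(self, x):
        grads = []
        for c in range(x.shape[1]):  # per-channel vertical/horizontal gradients
            ch = x[:, c:c + 1]
            gv = F.conv2d(ch, self.kernel_v, padding=1)
            gh = F.conv2d(ch, self.kernel_h, padding=1)
            grads.append(torch.sqrt(gv ** 2 + gh ** 2 + 1e-6))
        return torch.cat(grads, dim=1)

x_grad = GetGradientNopadding()(torch.rand(1, 3, 64, 64))  # -> (1, 3, 64, 64)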
@@ -67,7 +67,10 @@ class FfmpegBackedVideoDataset(data.Dataset):
         img_LQ = F.crop(img_LQ, 0, left, h, w_per_split)
         img_LQ = F.to_tensor(img_LQ)
 
-        return {'LQ': img_LQ}
+        mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])
+        ref = torch.cat([img_LQ, mask], dim=0)
+        return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+                'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
 
     def __len__(self):
         return self.frame_count * self.vertical_splits
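Note: each dataset item now also carries a full-size reference (the LQ frame with an all-ones mask channel appended) and the frame's center coordinate, matching the (ref, center_coord) inputs that SwitchedSpsrWithRef2.forward expects. A standalone illustration with a made-up frame size:

import torch

img_LQ = torch.rand(3, 270, 480)  # stand-in for one cropped, to_tensor'd frame
mask = torch.ones(1, img_LQ.shape[1], img_LQ.shape[2])  # all-ones mask channel
ref = torch.cat([img_LQ, mask], dim=0)                  # shape: (4, 270, 480)
center = torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2],
                      dtype=torch.long)                 # (H//2, W//2) = (135, 240)
item = {'LQ': img_LQ, 'lq_fullsize_ref': ref, 'lq_center': center}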