forked from mrq/DL-Art-School

Mods needed to support SPSR archs with teco gan

commit e785029936 (parent 120072d464)
@@ -234,10 +234,18 @@ class ExtensibleTrainer(BaseModel):
                 continue  # This can happen for several reasons (ex: 'after' defs), just ignore it.
             if step % self.opt['logger']['visual_debug_rate'] == 0:
                 for i, dbgv in enumerate(state[v]):
-                    if dbgv.shape[1] > 3:
-                        dbgv = dbgv[:,:3,:,:]
-                    os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                    if 'recurrent_visual_indices' in self.opt['logger'].keys():
+                        for rvi in self.opt['logger']['recurrent_visual_indices']:
+                            rdbgv = dbgv[:, rvi]
+                            if rdbgv.shape[1] > 3:
+                                rdbgv = rdbgv[:, :3, :, :]
+                            os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
+                            utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                    else:
+                        if dbgv.shape[1] > 3:
+                            dbgv = dbgv[:,:3,:,:]
+                        os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
+                        utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))

     def compute_fea_loss(self, real, fake):
         with torch.no_grad():
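With 'recurrent_visual_indices' set in the logger options, the debug tensors are evidently 5D sequences of shape [batch, frame, channels, height, width], and only the listed frame indices get saved. A minimal sketch of the same slicing, using torchvision's save_image and a made-up 5D tensor:

    import os
    import torch
    from torchvision.utils import save_image

    dbgv = torch.rand(4, 10, 3, 32, 32)      # hypothetical [batch, frames, c, h, w] debug tensor
    sample_save_path, v, step, i = "debug", "gen_out", 100, 0
    for rvi in [0, 5, 9]:                     # e.g. logger['recurrent_visual_indices']
        rdbgv = dbgv[:, rvi]                  # frame rvi -> [batch, c, h, w]
        if rdbgv.shape[1] > 3:
            rdbgv = rdbgv[:, :3]              # keep 3 channels so it saves as RGB
        os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
        save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))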
@@ -460,13 +460,19 @@ class Spsr6(nn.Module):

 # Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding.
 class Spsr7(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10):
+    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10):
         super(Spsr7, self).__init__()
         n_upscale = int(math.log(upscale, 2))

         # processing the input embedding
         self.reference_embedding = ReferenceImageBranch(nf)

+        self.recurrent = recurrent
+        if recurrent:
+            self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False,
+                                                   bias=True)
+            self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01)
+
         # switch options
         self.nf = nf
         transformation_filters = nf
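The new recurrent branch embeds a 3-channel recurrent image with a stride-2 conv and folds it into the feature trunk through a 1x1 conv built with weight_init_factor=.01, presumably so the recurrent contribution starts near zero and is learned gradually. A rough stand-in using plain PyTorch (ConvGnLelu is this repo's conv/GroupNorm/LeakyReLU block; norm and activation are disabled here, so bare convs approximate it):

    import torch
    import torch.nn as nn

    nf = 64
    model_recurrent_conv = nn.Conv2d(3, nf, kernel_size=3, stride=2, padding=1, bias=True)
    model_fea_recurrent_combine = nn.Conv2d(nf * 2, nf, kernel_size=1, bias=False)
    with torch.no_grad():
        # approximates weight_init_factor=.01: shrink the combine conv so the
        # recurrent residual is almost a no-op at initialization
        model_fea_recurrent_combine.weight.mul_(0.01)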
@@ -522,7 +528,7 @@ class Spsr7(nn.Module):
         self.final_temperature_step = 10000
         self.lr = None

-    def forward(self, x, ref, ref_center, update_attention_norm=True):
+    def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()

@@ -531,6 +537,11 @@ class Spsr7(nn.Module):
         ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)

         x = self.model_fea_conv(x)
+        if self.recurrent:
+            rec = self.model_recurrent_conv(recurrent)
+            br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1))
+            x = x + br
+
         x1 = x
         x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding))
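In forward() the fusion is a residual add: the recurrent image is downsampled to the trunk's feature resolution, concatenated channel-wise, reduced back to nf channels, and added onto x. Continuing the stand-in above (shapes are assumptions; the stride-2 conv implies the recurrent input arrives at twice the feature resolution):

    x = torch.rand(1, nf, 32, 32)          # trunk features after model_fea_conv
    recurrent = torch.rand(1, 3, 64, 64)   # previous output frame, 2x feature resolution

    rec = model_recurrent_conv(recurrent)                         # -> [1, nf, 32, 32]
    br = model_fea_recurrent_combine(torch.cat([x, rec], dim=1))  # [1, 2*nf, ...] -> [1, nf, ...]
    x = x + br                                                    # near-zero correction at init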
@@ -71,10 +71,11 @@ def define_G(opt, net_key='network_G', scale=None):
                           multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
                           init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == "spsr7":
+        recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                           multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
-                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
     elif which_model == "spsr9":
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
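Enabling the path from a training config then only requires the new 'recurrent' key. A hypothetical opt_net fragment (the other keys follow the defaults visible in define_G; the dispatch key is assumed to be 'which_model_G' as elsewhere in this codebase):

    opt_net = {
        'which_model_G': 'spsr7',
        'nf': 64,
        'scale': 4,
        'recurrent': True,            # new; defaults to False when absent
        'num_transforms': 8,          # -> xforms
        'multiplexer_reductions': 3,
        'temperature': 10,            # -> init_temperature
    }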
@@ -34,8 +34,8 @@ def create_injector(opt_inject, env):
         return ConcatenateInjector(opt_inject, env)
     elif type == 'margin_removal':
         return MarginRemoval(opt_inject, env)
-    elif type == 'constant':
-        return ConstantInjector(opt_inject, env)
+    elif type == 'foreach':
+        return ForEachInjector(opt_inject, env)
     else:
         raise NotImplementedError
@@ -221,18 +221,21 @@ class MarginRemoval(Injector):
         return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}


-class ConstantInjector(Injector):
+# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
+class ForEachInjector(Injector):
     def __init__(self, opt, env):
-        super(ConstantInjector, self).__init__(opt, env)
-        self.constant_type = opt['constant_type']
-        self.dim = opt['dim']
-        self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
+        super(ForEachInjector, self).__init__(opt, env)
+        o = opt.copy()
+        o['type'] = opt['subtype']
+        o['in'] = '_in'
+        o['out'] = '_out'
+        self.injector = create_injector(o, self.env)

     def forward(self, state):
-        bs = state[self.like].shape[0]
-        dev = state[self.like].device
-        if self.constant_type == 'zeros':
-            out = torch.zeros((bs,) + tuple(self.dim), device=dev)
-        else:
-            raise NotImplementedError
-        return { self.opt['out']: out }
+        injs = []
+        st = state.copy()
+        inputs = state[self.opt['in']]
+        for i in range(inputs.shape[1]):
+            st['_in'] = inputs[:, i]
+            injs.append(self.injector(st)['_out'])
+        return {self.output: torch.stack(injs, dim=1)}
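ForEachInjector wraps a single sub-injector (chosen by 'subtype') and applies it once per slice along dim 1, rebinding the child's 'in'/'out' to scratch keys and re-stacking the results into [batch, steps, ...]. A self-contained sketch of the same pattern, with a lambda standing in for the real sub-injector:

    import torch

    def for_each(state, in_key, out_key, sub_injector):
        # Apply sub_injector to each dim-1 slice of state[in_key], then re-stack.
        st = dict(state)
        outs = []
        inputs = state[in_key]                  # [batch, steps, ...]
        for i in range(inputs.shape[1]):
            st['_in'] = inputs[:, i]
            outs.append(sub_injector(st)['_out'])
        return {out_key: torch.stack(outs, dim=1)}

    double = lambda st: {'_out': st['_in'] * 2}       # toy sub-injector
    state = {'video': torch.rand(2, 5, 3, 16, 16)}
    print(for_each(state, 'video', 'doubled', double)['doubled'].shape)  # [2, 5, 3, 16, 16]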
@@ -394,7 +394,7 @@ class RecurrentLoss(ConfigurableLoss):
         real = state[self.opt['real']]
         for i in range(real.shape[1]):
             st['_real'] = real[:, i]
-            st['_fake'] = state[self.opt['fake']][i]
+            st['_fake'] = state[self.opt['fake']][:, i]
             total_loss += self.loss(net, st)
         return total_loss

@@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
     def forward(self, net, state):
         st = state.copy()
         st['_real'] = state[self.opt['real']][:, self.index]
-        st['_fake'] = state[self.opt['fake']][self.index]
+        st['_fake'] = state[self.opt['fake']][:, self.index]
         return self.loss(net, st)
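Both one-character fixes follow from the same convention change: sequence outputs now arrive as stacked tensors of shape [batch, sequence, ...] rather than Python lists, so a per-frame slice is tensor[:, i] instead of list[i]. For instance:

    import torch

    frames = [torch.rand(2, 3, 16, 16) for _ in range(5)]
    fake_list = frames                  # old convention: indexed as fake_list[i]
    fake = torch.stack(frames, dim=1)   # new convention: [batch, sequence, c, h, w]
    assert torch.equal(fake[:, 3], fake_list[3])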
@@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
         self.flow = opt['flow_network']
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
-        self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
+        self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
         self.recurrent_index = opt['recurrent_index']
         self.scale = opt['scale']
         self.resample = Resample2d()
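One thing to watch in this hunk: the new guard tests for 'output_index' but still reads 'output_hq_index', so a config that sets only 'output_hq_index' now silently falls back to 0, and one that sets only 'output_index' raises KeyError. If that is unintentional, dict.get expresses the same defaulting without duplicating the key name:

    opt = {'recurrent_index': 1, 'scale': 4, 'flow_network': 'flownet'}
    # equivalent to: opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
    output_hq_index = opt.get('output_hq_index', 0)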
@@ -71,12 +71,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         flow = self.env['generators'][self.flow]
-        results = []
         first_inputs = extract_params_from_state(self.first_inputs, state)
         inputs = extract_params_from_state(self.input, state)
         if not isinstance(inputs, list):
             inputs = [inputs]

+        if not isinstance(self.output, list):
+            self.output = [self.output]
+        results = {}
+        for out_key in self.output:
+            results[out_key] = []
+
         # Go forward in the sequence first.
         first_step = True
         b, f, c, h, w = inputs[self.input_lq_index].shape
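results changes from one flat list to a dict of per-key lists, so generators that return several tensors per step (which, given the commit title, the SPSR archs presumably do) get every output collected under its own name. The normalization boils down to:

    output = 'hq'                       # opt['out'] may be a single key or a list of keys
    if not isinstance(output, list):
        output = [output]
    results = {out_key: [] for out_key in output}   # one per-frame list per output key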
@@ -101,8 +106,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
             gen_out = gen(*input)
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
+            for i, out_key in enumerate(self.output):
+                results[out_key].append(gen_out[i])
             recurrent_input = gen_out[self.output_hq_index]
-            results.append(recurrent_input)

         # Now go backwards, skipping the last element (it's already stored in recurrent_input)
         if self.do_backwards:
@@ -122,10 +128,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 gen_out = gen(*input)
                 if isinstance(gen_out, torch.Tensor):
                     gen_out = [gen_out]
+                for i, out_key in enumerate(self.output):
+                    results[out_key].append(gen_out[i])
                 recurrent_input = gen_out[self.output_hq_index]
-                results.append(recurrent_input)

-        return {self.output: results}
+        for k, v in results.items():
+            results[k] = torch.stack(v, dim=1)
+        return results

     def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
         if self.env['rank'] > 0:
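After the forward and backward passes, every per-key list is stacked along a new dim-1 sequence axis, so consumers such as RecurrentLoss above and TecoGanLoss below receive [batch, sequence, ...] tensors directly. Sketch:

    import torch

    results = {'hq': [torch.rand(2, 3, 64, 64) for _ in range(5)]}
    for k, v in results.items():
        results[k] = torch.stack(v, dim=1)
    print(results['hq'].shape)   # torch.Size([2, 5, 3, 64, 64])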
@@ -183,7 +192,7 @@ class TecoGanLoss(ConfigurableLoss):
         net = self.env['discriminators'][self.opt['discriminator']]
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
-        fake = torch.stack(state[self.opt['fake']], dim=1)
+        fake = state[self.opt['fake']]
         sequence_len = real.shape[1]
         lr = state[self.opt['lr_inputs']]
         l_total = 0
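With the injector now doing the stacking, TecoGanLoss takes state[self.opt['fake']] as-is; real and fake share one layout and sequence_len reads off either. Assumed shapes:

    import torch

    real = torch.rand(2, 5, 3, 64, 64)   # [batch, sequence, c, h, w]
    fake = torch.rand(2, 5, 3, 64, 64)   # already stacked by the injector
    sequence_len = real.shape[1]
    assert fake.shape[1] == sequence_len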
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_spsr7.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()