Mods needed to support SPSR archs with teco gan

This commit is contained in:
James Betker 2020-10-10 22:39:55 -06:00
parent 120072d464
commit e785029936
7 changed files with 62 additions and 30 deletions

View File

@ -234,10 +234,18 @@ class ExtensibleTrainer(BaseModel):
continue # This can happen for several reasons (ex: 'after' defs), just ignore it. continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
if step % self.opt['logger']['visual_debug_rate'] == 0: if step % self.opt['logger']['visual_debug_rate'] == 0:
for i, dbgv in enumerate(state[v]): for i, dbgv in enumerate(state[v]):
if dbgv.shape[1] > 3: if 'recurrent_visual_indices' in self.opt['logger'].keys():
dbgv = dbgv[:,:3,:,:] for rvi in self.opt['logger']['recurrent_visual_indices']:
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) rdbgv = dbgv[:, rvi]
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
def compute_fea_loss(self, real, fake): def compute_fea_loss(self, real, fake):
with torch.no_grad(): with torch.no_grad():

View File

@ -460,13 +460,19 @@ class Spsr6(nn.Module):
# Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding. # Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding.
class Spsr7(nn.Module): class Spsr7(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10):
super(Spsr7, self).__init__() super(Spsr7, self).__init__()
n_upscale = int(math.log(upscale, 2)) n_upscale = int(math.log(upscale, 2))
# processing the input embedding # processing the input embedding
self.reference_embedding = ReferenceImageBranch(nf) self.reference_embedding = ReferenceImageBranch(nf)
self.recurrent = recurrent
if recurrent:
self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False,
bias=True)
self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01)
# switch options # switch options
self.nf = nf self.nf = nf
transformation_filters = nf transformation_filters = nf
@ -522,7 +528,7 @@ class Spsr7(nn.Module):
self.final_temperature_step = 10000 self.final_temperature_step = 10000
self.lr = None self.lr = None
def forward(self, x, ref, ref_center, update_attention_norm=True): def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None):
# The attention_maps debugger outputs <x>. Save that here. # The attention_maps debugger outputs <x>. Save that here.
self.lr = x.detach().cpu() self.lr = x.detach().cpu()
@ -531,6 +537,11 @@ class Spsr7(nn.Module):
ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
x = self.model_fea_conv(x) x = self.model_fea_conv(x)
if self.recurrent:
rec = self.model_recurrent_conv(recurrent)
br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1))
x = x + br
x1 = x x1 = x
x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding)) x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding))

View File

@ -71,10 +71,11 @@ def define_G(opt, net_key='network_G', scale=None):
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "spsr7": elif which_model == "spsr7":
recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
elif which_model == "spsr9": elif which_model == "spsr9":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],

View File

@ -34,8 +34,8 @@ def create_injector(opt_inject, env):
return ConcatenateInjector(opt_inject, env) return ConcatenateInjector(opt_inject, env)
elif type == 'margin_removal': elif type == 'margin_removal':
return MarginRemoval(opt_inject, env) return MarginRemoval(opt_inject, env)
elif type == 'constant': elif type == 'foreach':
return ConstantInjector(opt_inject, env) return ForEachInjector(opt_inject, env)
else: else:
raise NotImplementedError raise NotImplementedError
@ -221,18 +221,21 @@ class MarginRemoval(Injector):
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
class ConstantInjector(Injector): # Produces an injection which is composed of applying a single injector multiple times across a single dimension.
class ForEachInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env) super(ForEachInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type'] o = opt.copy()
self.dim = opt['dim'] o['type'] = opt['subtype']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. o['in'] = '_in'
o['out'] = '_out'
self.injector = create_injector(o, self.env)
def forward(self, state): def forward(self, state):
bs = state[self.like].shape[0] injs = []
dev = state[self.like].device st = state.copy()
if self.constant_type == 'zeros': inputs = state[self.opt['in']]
out = torch.zeros((bs,) + tuple(self.dim), device=dev) for i in range(inputs.shape[1]):
else: st['_in'] = inputs[:, i]
raise NotImplementedError injs.append(self.injector(st)['_out'])
return { self.opt['out']: out } return {self.output: torch.stack(injs, dim=1)}

View File

@ -394,7 +394,7 @@ class RecurrentLoss(ConfigurableLoss):
real = state[self.opt['real']] real = state[self.opt['real']]
for i in range(real.shape[1]): for i in range(real.shape[1]):
st['_real'] = real[:, i] st['_real'] = real[:, i]
st['_fake'] = state[self.opt['fake']][i] st['_fake'] = state[self.opt['fake']][:, i]
total_loss += self.loss(net, st) total_loss += self.loss(net, st)
return total_loss return total_loss
@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
def forward(self, net, state): def forward(self, net, state):
st = state.copy() st = state.copy()
st['_real'] = state[self.opt['real']][:, self.index] st['_real'] = state[self.opt['real']][:, self.index]
st['_fake'] = state[self.opt['fake']][self.index] st['_fake'] = state[self.opt['fake']][:, self.index]
return self.loss(net, st) return self.loss(net, st)

View File

@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
self.flow = opt['flow_network'] self.flow = opt['flow_network']
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
self.recurrent_index = opt['recurrent_index'] self.recurrent_index = opt['recurrent_index']
self.scale = opt['scale'] self.scale = opt['scale']
self.resample = Resample2d() self.resample = Resample2d()
@ -71,12 +71,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
flow = self.env['generators'][self.flow] flow = self.env['generators'][self.flow]
results = []
first_inputs = extract_params_from_state(self.first_inputs, state) first_inputs = extract_params_from_state(self.first_inputs, state)
inputs = extract_params_from_state(self.input, state) inputs = extract_params_from_state(self.input, state)
if not isinstance(inputs, list): if not isinstance(inputs, list):
inputs = [inputs] inputs = [inputs]
if not isinstance(self.output, list):
self.output = [self.output]
results = {}
for out_key in self.output:
results[out_key] = []
# Go forward in the sequence first. # Go forward in the sequence first.
first_step = True first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape b, f, c, h, w = inputs[self.input_lq_index].shape
@ -101,8 +106,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
# Now go backwards, skipping the last element (it's already stored in recurrent_input) # Now go backwards, skipping the last element (it's already stored in recurrent_input)
if self.do_backwards: if self.do_backwards:
@ -122,10 +128,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out] gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
return {self.output: results} for k, v in results.items():
results[k] = torch.stack(v, dim=1)
return results
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it): def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
if self.env['rank'] > 0: if self.env['rank'] > 0:
@ -183,7 +192,7 @@ class TecoGanLoss(ConfigurableLoss):
net = self.env['discriminators'][self.opt['discriminator']] net = self.env['discriminators'][self.opt['discriminator']]
flow_gen = self.env['generators'][self.image_flow_generator] flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']] real = state[self.opt['real']]
fake = torch.stack(state[self.opt['fake']], dim=1) fake = state[self.opt['fake']]
sequence_len = real.shape[1] sequence_len = real.shape[1]
lr = state[self.opt['lr_inputs']] lr = state[self.opt['lr_inputs']]
l_total = 0 l_total = 0

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_spsr7.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()