Mods needed to support SPSR archs with teco gan
This commit is contained in:
parent
120072d464
commit
e785029936
|
@ -234,10 +234,18 @@ class ExtensibleTrainer(BaseModel):
|
|||
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
||||
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
for i, dbgv in enumerate(state[v]):
|
||||
if dbgv.shape[1] > 3:
|
||||
dbgv = dbgv[:,:3,:,:]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
||||
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
||||
rdbgv = dbgv[:, rvi]
|
||||
if rdbgv.shape[1] > 3:
|
||||
rdbgv = rdbgv[:, :3, :, :]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||
else:
|
||||
if dbgv.shape[1] > 3:
|
||||
dbgv = dbgv[:,:3,:,:]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -460,13 +460,19 @@ class Spsr6(nn.Module):
|
|||
|
||||
# Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding.
|
||||
class Spsr7(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10):
|
||||
super(Spsr7, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
# processing the input embedding
|
||||
self.reference_embedding = ReferenceImageBranch(nf)
|
||||
|
||||
self.recurrent = recurrent
|
||||
if recurrent:
|
||||
self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False,
|
||||
bias=True)
|
||||
self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01)
|
||||
|
||||
# switch options
|
||||
self.nf = nf
|
||||
transformation_filters = nf
|
||||
|
@ -522,7 +528,7 @@ class Spsr7(nn.Module):
|
|||
self.final_temperature_step = 10000
|
||||
self.lr = None
|
||||
|
||||
def forward(self, x, ref, ref_center, update_attention_norm=True):
|
||||
def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None):
|
||||
# The attention_maps debugger outputs <x>. Save that here.
|
||||
self.lr = x.detach().cpu()
|
||||
|
||||
|
@ -531,6 +537,11 @@ class Spsr7(nn.Module):
|
|||
ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
if self.recurrent:
|
||||
rec = self.model_recurrent_conv(recurrent)
|
||||
br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1))
|
||||
x = x + br
|
||||
|
||||
x1 = x
|
||||
x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding))
|
||||
|
||||
|
|
|
@ -71,10 +71,11 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||
elif which_model == "spsr7":
|
||||
recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
|
||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
|
||||
elif which_model == "spsr9":
|
||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
|
|
|
@ -34,8 +34,8 @@ def create_injector(opt_inject, env):
|
|||
return ConcatenateInjector(opt_inject, env)
|
||||
elif type == 'margin_removal':
|
||||
return MarginRemoval(opt_inject, env)
|
||||
elif type == 'constant':
|
||||
return ConstantInjector(opt_inject, env)
|
||||
elif type == 'foreach':
|
||||
return ForEachInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -221,18 +221,21 @@ class MarginRemoval(Injector):
|
|||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||
|
||||
|
||||
class ConstantInjector(Injector):
|
||||
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
|
||||
class ForEachInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ConstantInjector, self).__init__(opt, env)
|
||||
self.constant_type = opt['constant_type']
|
||||
self.dim = opt['dim']
|
||||
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||
super(ForEachInjector, self).__init__(opt, env)
|
||||
o = opt.copy()
|
||||
o['type'] = opt['subtype']
|
||||
o['in'] = '_in'
|
||||
o['out'] = '_out'
|
||||
self.injector = create_injector(o, self.env)
|
||||
|
||||
def forward(self, state):
|
||||
bs = state[self.like].shape[0]
|
||||
dev = state[self.like].device
|
||||
if self.constant_type == 'zeros':
|
||||
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
injs = []
|
||||
st = state.copy()
|
||||
inputs = state[self.opt['in']]
|
||||
for i in range(inputs.shape[1]):
|
||||
st['_in'] = inputs[:, i]
|
||||
injs.append(self.injector(st)['_out'])
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
|
|
|
@ -394,7 +394,7 @@ class RecurrentLoss(ConfigurableLoss):
|
|||
real = state[self.opt['real']]
|
||||
for i in range(real.shape[1]):
|
||||
st['_real'] = real[:, i]
|
||||
st['_fake'] = state[self.opt['fake']][i]
|
||||
st['_fake'] = state[self.opt['fake']][:, i]
|
||||
total_loss += self.loss(net, st)
|
||||
return total_loss
|
||||
|
||||
|
@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
|
|||
def forward(self, net, state):
|
||||
st = state.copy()
|
||||
st['_real'] = state[self.opt['real']][:, self.index]
|
||||
st['_fake'] = state[self.opt['fake']][self.index]
|
||||
st['_fake'] = state[self.opt['fake']][:, self.index]
|
||||
return self.loss(net, st)
|
||||
|
|
|
@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||
self.flow = opt['flow_network']
|
||||
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
|
||||
self.recurrent_index = opt['recurrent_index']
|
||||
self.scale = opt['scale']
|
||||
self.resample = Resample2d()
|
||||
|
@ -71,12 +71,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
flow = self.env['generators'][self.flow]
|
||||
results = []
|
||||
first_inputs = extract_params_from_state(self.first_inputs, state)
|
||||
inputs = extract_params_from_state(self.input, state)
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
|
||||
if not isinstance(self.output, list):
|
||||
self.output = [self.output]
|
||||
results = {}
|
||||
for out_key in self.output:
|
||||
results[out_key] = []
|
||||
|
||||
# Go forward in the sequence first.
|
||||
first_step = True
|
||||
b, f, c, h, w = inputs[self.input_lq_index].shape
|
||||
|
@ -101,8 +106,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
gen_out = gen(*input)
|
||||
if isinstance(gen_out, torch.Tensor):
|
||||
gen_out = [gen_out]
|
||||
for i, out_key in enumerate(self.output):
|
||||
results[out_key].append(gen_out[i])
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
|
||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||
if self.do_backwards:
|
||||
|
@ -122,10 +128,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
gen_out = gen(*input)
|
||||
if isinstance(gen_out, torch.Tensor):
|
||||
gen_out = [gen_out]
|
||||
for i, out_key in enumerate(self.output):
|
||||
results[out_key].append(gen_out[i])
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
|
||||
return {self.output: results}
|
||||
for k, v in results.items():
|
||||
results[k] = torch.stack(v, dim=1)
|
||||
return results
|
||||
|
||||
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
|
||||
if self.env['rank'] > 0:
|
||||
|
@ -183,7 +192,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
net = self.env['discriminators'][self.opt['discriminator']]
|
||||
flow_gen = self.env['generators'][self.image_flow_generator]
|
||||
real = state[self.opt['real']]
|
||||
fake = torch.stack(state[self.opt['fake']], dim=1)
|
||||
fake = state[self.opt['fake']]
|
||||
sequence_len = real.shape[1]
|
||||
lr = state[self.opt['lr_inputs']]
|
||||
l_total = 0
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_spsr7.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user