diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
index 31ea8c90..e1fce993 100644
--- a/codes/models/archs/ChainedEmbeddingGen.py
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -52,6 +52,75 @@ class BasicEmbeddingPyramid(nn.Module):
         return x, p
+
+
+class ChainedEmbeddingGenWithStructure(nn.Module):
+    def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2):
+        super(ChainedEmbeddingGenWithStructure, self).__init__()
+        self.recurrent = recurrent
+        self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False)
+        if recurrent:
+            self.recurrent_nf = recurrent_nf
+            self.recurrent_stride = recurrent_stride
+            self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
+            self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
+        self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
+        self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
+        self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
+        self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
+        self.structure_upsample = FinalUpsampleBlock2x(64)
+        self.grad_extract = ImageGradientNoPadding()
+        self.upsample = FinalUpsampleBlock2x(64)
+        self.ref_join_std = 0
+
+    def forward(self, x, recurrent=None):
+        fea = self.initial_conv(x)
+        if self.recurrent:
+            if recurrent is None:
+                if self.recurrent_nf == 3:
+                    recurrent = torch.zeros_like(x)
+                    if self.recurrent_stride != 1:
+                        recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
+                else:
+                    recurrent = torch.zeros_like(fea)
+            rec = self.recurrent_process(recurrent)
+            fea, recstd = self.recurrent_join(fea, rec)
+            self.ref_join_std = recstd.item()
+        if self.spine is not None:
+            emb = checkpoint(self.spine, fea)
+        else:
+            b,f,h,w = fea.shape
+            emb = (torch.zeros((b,f,h//2,w//2), device=fea.device),
+                   torch.zeros((b,f,h//4,w//4), device=fea.device))
+        grad = fea
+        for i, block in enumerate(self.blocks):
+            fea = fea + checkpoint(block, fea, *emb)[0]
+            if i < 3:
+                structure_br = checkpoint(self.structure_joins[i], grad, fea)
+                grad = grad + checkpoint(self.structure_blocks[i], structure_br)
+        out = checkpoint(self.upsample, fea)
+        return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
+
+    def get_debug_values(self, step, net_name):
+        return { 'ref_join_std': self.ref_join_std }
+
+
+# This is a structural block that learns to mute regions of a residual transformation given a signal.
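+# The switch signal is reduced to a single-channel map by a stack of 1x1 convolutions, and
+# sigmoid(bias + switch) scales the residual per-pixel. With the default initial_bias=10 the
+# sigmoid starts saturated near 1, so the block initially passes the residual through unchanged
+# and must learn where to mute it.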
+class OptionalPassthroughBlock(nn.Module):
+    def __init__(self, nf, initial_bias=10):
+        super(OptionalPassthroughBlock, self).__init__()
+        self.switch_process = nn.Sequential(ConvGnLelu(nf, nf // 2, 1, activation=False, norm=False, bias=False),
+                                            ConvGnLelu(nf // 2, nf // 4, 1, activation=False, norm=False, bias=False),
+                                            ConvGnLelu(nf // 4, 1, 1, activation=False, norm=False, bias=False))
+        self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float), requires_grad=True)
+        self.activation = nn.Sigmoid()
+
+    def forward(self, x, switch_signal):
+        switch = self.switch_process(switch_signal)
+        bypass_map = self.activation(self.bias + switch)
+        return x * bypass_map, bypass_map
+
+
 class MultifacetedChainedEmbeddingGen(nn.Module):
     def __init__(self, depth=10, scale=2):
         super(MultifacetedChainedEmbeddingGen, self).__init__()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index ea266a23..e10064a3 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -76,6 +76,12 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                           multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
                           init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
+    elif which_model == 'chained_gen_structured':
+        rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
+        recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
+        recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
+        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
+        netG = chained.ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc)
     elif which_model == 'multifaceted_chained':
         scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
         netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
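
Usage note: a minimal sketch (not part of the diff) of how the new generator would be driven, assuming
the repo's codes/ directory is on the import path; tensor sizes and parameter values here are
illustrative, not taken from the source.

    import torch
    from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGenWithStructure

    netG = ChainedEmbeddingGenWithStructure(in_nc=3, depth=10, recurrent=True)
    lr = torch.randn(1, 3, 32, 32)

    # First pass: recurrent=None, so forward() synthesizes a zero recurrent input.
    # Returns the 2x output, the structure-branch gradient map, the output's
    # gradient map, and the final feature tensor.
    out, structure_grad, out_grad, fea = netG(lr)

    # Later passes: feed the previous 2x output back in. With the defaults
    # recurrent_nf=3 and recurrent_stride=2, the stride-2 recurrent_process conv
    # expects a 3-channel image at twice the LR resolution, which out satisfies.
    out2, *_ = netG(lr, recurrent=out)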