diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index d82f00d7..4e454a27 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -71,6 +71,8 @@ class ChainedEmbeddingGenWithStructure(nn.Module): self.recurrent = recurrent self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) if recurrent: + self.recurrent_nf = recurrent_nf + self.recurrent_stride = recurrent_stride self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False) self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) @@ -86,7 +88,12 @@ class ChainedEmbeddingGenWithStructure(nn.Module): fea = self.initial_conv(x) if self.recurrent: if recurrent is None: - recurrent = torch.zeros_like(fea) + if self.recurrent_nf == 3: + recurrent = torch.zeros_like(x) + if self.recurrent_stride != 1: + recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest') + else: + recurrent = torch.zeros_like(fea) rec = self.recurrent_process(recurrent) fea, recstd = self.recurrent_join(fea, rec) self.ref_join_std = recstd.item() @@ -101,4 +108,4 @@ class ChainedEmbeddingGenWithStructure(nn.Module): return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea def get_debug_values(self, step, net_name): - return { 'ref_join_std': self.ref_join_std } \ No newline at end of file + return { 'ref_join_std': self.ref_join_std }