diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
index e1fce993..f283e1b3 100644
--- a/codes/models/archs/ChainedEmbeddingGen.py
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -128,7 +128,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
         self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
 
-        self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
+        if scale == 2:
+            self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
+        else:
+            self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False)
         self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
 
         self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
 
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index ea08aecc..47664e8a 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -108,7 +108,9 @@ class BaseModel():
 
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
-            if k.startswith('module.'):
+            if 'teco_recurrent_process' in k:
+                continue
+            elif k.startswith('module.'):
                 load_net_clean[k[7:]] = v
             else:
                 load_net_clean[k] = v
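
The ChainedEmbeddingGen change makes the stride of teco_recurrent_process track the model's scale factor: the recurrent input evidently arrives at output resolution (hence the 3 input channels and a stride equal to the scale), while the embedding trunk operates on the low-resolution grid, so a stride-2 conv only covers the 2x case and 4x models need a stride-4 reduction. Below is a minimal sketch of that shape arithmetic, assuming ConvGnLelu with norm and activation disabled behaves like a padded Conv2d; the make_recurrent_process helper and the sizes are illustrative, not repo code.

    import torch
    import torch.nn as nn

    def make_recurrent_process(scale: int) -> nn.Module:
        # Mirrors the diff's branch: pick the downsampling conv by scale.
        if scale == 2:
            # 2x models: one stride-2 conv brings the recurrent frame down to LR.
            return nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        else:
            # Larger scales need a stride-4 reduction; the 7x7 kernel keeps the
            # receptive field roughly proportional to the stride.
            return nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)

    for scale in (2, 4):
        h = w = 32  # the trunk's internal (LR) resolution, chosen for the demo
        prev_output = torch.randn(1, 3, scale * h, scale * w)  # last output frame
        feat = make_recurrent_process(scale)(prev_output)
        assert feat.shape[-2:] == (h, w)  # recurrent features land on the LR grid

The base_model.py half of the change drops any 'teco_recurrent_process' keys while cleaning the loaded state dict, so checkpoints saved before this layer's shape became scale-dependent no longer trigger a size-mismatch on load; the layer is simply left at its fresh initialization. This presumably relies on the network being loaded non-strictly (strict=False), since a strict load_state_dict would still report the skipped key as missing.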