ChainedGen 4x alteration

Increases conv window for teco_recurrent in the 4x case so all data
can be used.

base_model changes should be temporary.
This commit is contained in:
James Betker 2020-10-26 10:54:51 -06:00
parent 85c07f85d9
commit 629b968901
2 changed files with 7 additions and 2 deletions

View File

@@ -128,7 +128,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
if scale == 2:
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
else:
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False)
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)

View File

@@ -108,7 +108,9 @@ class BaseModel():
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
if 'teco_recurrent_process' in k:
continue
elif k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v