ChainedGen 4x alteration
Increases conv window for teco_recurrent in the 4x case so all data can be used. base_model changes should be temporary.
This commit is contained in:
parent
85c07f85d9
commit
629b968901
|
@ -128,7 +128,10 @@ class MultifacetedChainedEmbeddingGen(nn.Module):
|
|||
|
||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||
|
||||
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
|
||||
if scale == 2:
|
||||
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=3, stride=2, norm=False, bias=True, activation=False)
|
||||
else:
|
||||
self.teco_recurrent_process = ConvGnLelu(3, 64, kernel_size=7, stride=4, norm=False, bias=True, activation=False)
|
||||
self.teco_recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||
|
||||
self.prog_recurrent_process = ConvGnLelu(64, 64, kernel_size=3, stride=1, norm=False, bias=True, activation=False)
|
||||
|
|
|
@ -108,7 +108,9 @@ class BaseModel():
|
|||
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
if 'teco_recurrent_process' in k:
|
||||
continue
|
||||
elif k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
|
|
Loading…
Reference in New Issue
Block a user