Fix recurrent=None bug in ChainedEmbeddingGenWithStructure

This commit is contained in:
James Betker 2020-10-19 15:25:12 -06:00
parent 331c40f0c8
commit 1b1ca297f8

View File

@@ -71,6 +71,8 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
         self.recurrent = recurrent
         self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
         if recurrent:
+            self.recurrent_nf = recurrent_nf
+            self.recurrent_stride = recurrent_stride
             self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
             self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
         self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
@@ -86,7 +88,12 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
         fea = self.initial_conv(x)
         if self.recurrent:
             if recurrent is None:
-                recurrent = torch.zeros_like(fea)
+                if self.recurrent_nf == 3:
+                    recurrent = torch.zeros_like(x)
+                    if self.recurrent_stride != 1:
+                        recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
+                else:
+                    recurrent = torch.zeros_like(fea)
             rec = self.recurrent_process(recurrent)
             fea, recstd = self.recurrent_join(fea, rec)
             self.ref_join_std = recstd.item()
@ -101,4 +108,4 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
def get_debug_values(self, step, net_name): def get_debug_values(self, step, net_name):
return { 'ref_join_std': self.ref_join_std } return { 'ref_join_std': self.ref_join_std }