diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index fc7725ce..af0139ac 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -85,6 +85,7 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C): late_stage_padding=0) self.heads = heads self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1)) + arch_util.initialize_weights(self.converter) # Accepts input of shape (b, heads, f, w, h) def forward(self, x): diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index ecb7be76..f79d4532 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -13,7 +13,7 @@ def initialize_weights(net_l, scale=1): net_l = [net_l] for net in net_l: for m in net.modules(): - if isinstance(m, nn.Conv2d): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale # for residual block if m.bias is not None: diff --git a/codes/models/networks.py b/codes/models/networks.py index cc1fb6e0..1da351dc 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -80,14 +80,7 @@ def define_G(opt, net_key='network_G'): '''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'], assembler_blocks=opt_net['assembler_blocks'])''' - netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf']) - # video restoration - elif which_model == 'EDVR': - netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'], - groups=opt_net['groups'], front_RBs=opt_net['front_RBs'], - back_RBs=opt_net['back_RBs'], center=opt_net['center'], - predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], - w_TSA=opt_net['w_TSA']) + netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])\ else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))