Fix initialization in mhead switched rrdb

This commit is contained in:
James Betker 2020-06-15 21:32:03 -06:00
parent be7982b9ae
commit 0a714e8451
3 changed files with 3 additions and 9 deletions

View File

@ -85,6 +85,7 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
late_stage_padding=0) late_stage_padding=0)
self.heads = heads self.heads = heads
self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1)) self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
arch_util.initialize_weights(self.converter)
# Accepts input of shape (b, heads, f, w, h) # Accepts input of shape (b, heads, f, w, h)
def forward(self, x): def forward(self, x):

View File

@ -13,7 +13,7 @@ def initialize_weights(net_l, scale=1):
net_l = [net_l] net_l = [net_l]
for net in net_l: for net in net_l:
for m in net.modules(): for m in net.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in') init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block m.weight.data *= scale # for residual block
if m.bias is not None: if m.bias is not None:

View File

@ -80,14 +80,7 @@ def define_G(opt, net_key='network_G'):
'''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], '''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'], nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
assembler_blocks=opt_net['assembler_blocks'])''' assembler_blocks=opt_net['assembler_blocks'])'''
netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf']) netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])\
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
w_TSA=opt_net['w_TSA'])
else: else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))