Fix initialization in mhead switched rrdb
This commit is contained in:
parent
be7982b9ae
commit
0a714e8451
|
@ -85,6 +85,7 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
|
||||||
late_stage_padding=0)
|
late_stage_padding=0)
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
|
self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
|
||||||
|
arch_util.initialize_weights(self.converter)
|
||||||
|
|
||||||
# Accepts input of shape (b, heads, f, w, h)
|
# Accepts input of shape (b, heads, f, w, h)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -13,7 +13,7 @@ def initialize_weights(net_l, scale=1):
|
||||||
net_l = [net_l]
|
net_l = [net_l]
|
||||||
for net in net_l:
|
for net in net_l:
|
||||||
for m in net.modules():
|
for m in net.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d):
|
||||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
m.weight.data *= scale # for residual block
|
m.weight.data *= scale # for residual block
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
|
|
|
@ -80,14 +80,7 @@ def define_G(opt, net_key='network_G'):
|
||||||
'''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
'''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
|
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
|
||||||
assembler_blocks=opt_net['assembler_blocks'])'''
|
assembler_blocks=opt_net['assembler_blocks'])'''
|
||||||
netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])
|
netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])\
|
||||||
# video restoration
|
|
||||||
elif which_model == 'EDVR':
|
|
||||||
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
|
||||||
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
|
|
||||||
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
|
|
||||||
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
|
|
||||||
w_TSA=opt_net['w_TSA'])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user