Enable RRDB to take in reference inputs
This commit is contained in:
parent
7d38381d46
commit
3791f95ad0
|
@ -149,7 +149,10 @@ class RRDBNet(nn.Module):
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
self.in_channels = in_channels
|
||||||
|
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||||
|
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||||
|
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1)
|
||||||
self.body = make_layer(
|
self.body = make_layer(
|
||||||
body_block,
|
body_block,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
@ -170,7 +173,7 @@ class RRDBNet(nn.Module):
|
||||||
]:
|
]:
|
||||||
default_init_weights(m, 0.1)
|
default_init_weights(m, 0.1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, ref=None):
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -179,8 +182,12 @@ class RRDBNet(nn.Module):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: Forward results.
|
Tensor: Forward results.
|
||||||
"""
|
"""
|
||||||
|
if self.in_channels > 4:
|
||||||
feat = self.conv_first(x)
|
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||||
|
if ref is None:
|
||||||
|
ref = torch.zeros_like(x_lg)
|
||||||
|
x_lg = torch.cat([x_lg, ref])
|
||||||
|
feat = self.conv_first(x_lg)
|
||||||
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
||||||
feat = feat + body_feat
|
feat = feat + body_feat
|
||||||
# upsample
|
# upsample
|
||||||
|
|
Loading…
Reference in New Issue
Block a user