Enable RRDB to take in reference inputs

This commit is contained in:
James Betker 2020-10-29 11:07:40 -06:00
parent 7d38381d46
commit 3791f95ad0

View File

@ -149,7 +149,10 @@ class RRDBNet(nn.Module):
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale self.scale = scale
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) self.in_channels = in_channels
first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1)
self.body = make_layer( self.body = make_layer(
body_block, body_block,
num_blocks, num_blocks,
@ -170,7 +173,7 @@ class RRDBNet(nn.Module):
]: ]:
default_init_weights(m, 0.1) default_init_weights(m, 0.1)
def forward(self, x): def forward(self, x, ref=None):
"""Forward function. """Forward function.
Args: Args:
@ -179,8 +182,12 @@ class RRDBNet(nn.Module):
Returns: Returns:
Tensor: Forward results. Tensor: Forward results.
""" """
if self.in_channels > 4:
feat = self.conv_first(x) x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref])
feat = self.conv_first(x_lg)
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)) body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
feat = feat + body_feat feat = feat + body_feat
# upsample # upsample