42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
import torch
|
|
from torch import nn
|
|
from lambda_networks import LambdaLayer
|
|
from torch.nn import GroupNorm
|
|
|
|
from models.archs.RRDBNet_arch import ResidualDenseBlock
|
|
|
|
|
|
class LambdaRRDB(nn.Module):
    """Residual in Residual Dense Block with a lambda layer.

    Used in RRDB-Net in ESRGAN.

    The input is passed through two ``ResidualDenseBlock`` modules, then a
    ``LambdaLayer``, then a ``GroupNorm``. The result is multiplied by a
    learnable scalar (initialised to 1/256) and added back onto the input.

    Args:
        mid_channels (int): Channel number of intermediate features.
        growth_channels (int): Channels for each growth.
        reduce_to (int | None): Output channel count of the lambda layer.
            ``None`` (the default) means ``mid_channels``.

    Note:
        ``GroupNorm`` is built with 8 groups, so the normalized channel count
        must be divisible by 8.

        NOTE(review): ``forward`` adds the block output to its input, so a
        ``reduce_to`` different from ``mid_channels`` only works when the two
        channel counts broadcast against each other -- confirm the intended
        behaviour with callers that pass ``reduce_to``.
    """

    def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
        super().__init__()
        if reduce_to is None:
            reduce_to = mid_channels

        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels, init_weight=1)
        self.lam = LambdaLayer(
            dim=mid_channels,
            dim_out=reduce_to,
            r=23,
            dim_k=16,
            heads=4,
            dim_u=4,
        )
        # The norm runs on the lambda layer's output, so it has to be sized
        # from that layer's output channels rather than from its input.
        self.gn = GroupNorm(num_groups=8, num_channels=reduce_to)
        # Learnable residual gain; starts small so the block is initially
        # close to an identity mapping.
        self.scale = nn.Parameter(torch.full((1,), 1 / 256))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.lam(out)
        out = self.gn(out)
        return out * self.scale + x