Fix scale torch warning

This commit is contained in:
James Betker 2020-07-31 16:56:04 -06:00
parent bcebed19b7
commit 8dd44182e6

View File

@ -18,7 +18,7 @@ class MultiConvBlock(nn.Module):
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] + self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] + [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
[ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)]) [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
self.bias = nn.Parameter(torch.zeros(1)) self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x, noise=None): def forward(self, x, noise=None):