Support silu activation

This commit is contained in:
James Betker 2021-08-25 09:03:14 -06:00
parent 67bf7f5219
commit 08b33c8e3a

View File

@ -28,13 +28,13 @@ def eval_decorator(fn):
class ResBlock(nn.Module):
def __init__(self, chan, conv):
def __init__(self, chan, conv, activation):
super().__init__()
self.net = nn.Sequential(
conv(chan, chan, 3, padding = 1),
nn.ReLU(),
activation(),
conv(chan, chan, 3, padding = 1),
nn.ReLU(),
activation(),
conv(chan, chan, 1)
)
@ -69,6 +69,7 @@ class DiscreteVAE(nn.Module):
kernel_size = 4,
use_transposed_convs = True,
encoder_norm = False,
activation = 'relu',
smooth_l1_loss = False,
straight_through = False,
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@ -94,6 +95,14 @@ class DiscreteVAE(nn.Module):
if not use_transposed_convs:
conv_transpose = functools.partial(UpsampledConv, conv)
if activation == 'relu':
act = nn.ReLU
elif activation == 'silu':
act = nn.SiLU
else:
assert NotImplementedError()
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
dec_chans = list(reversed(enc_chans))
@ -109,14 +118,14 @@ class DiscreteVAE(nn.Module):
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(dec_chans[1], conv))
enc_layers.append(ResBlock(enc_chans[-1], conv))
dec_layers.insert(0, ResBlock(dec_chans[1], conv, act))
enc_layers.append(ResBlock(enc_chans[-1], conv, act))
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))