Support silu activation
This commit is contained in:
parent 67bf7f5219
commit 08b33c8e3a
@@ -28,13 +28,13 @@ def eval_decorator(fn):
 
 
 class ResBlock(nn.Module):
-    def __init__(self, chan, conv):
+    def __init__(self, chan, conv, activation):
         super().__init__()
         self.net = nn.Sequential(
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 1)
         )
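For context, ResBlock now receives the activation as a module class instead of hard-coding nn.ReLU. A minimal sketch of constructing it directly (the diff does not show ResBlock's forward or which conv is passed in practice, so nn.Conv1d below is an illustrative choice only):

    import torch
    from torch import nn

    # ResBlock(chan, conv, activation): all three convs preserve the channel
    # count and spatial size, so the output shape matches the input shape.
    block = ResBlock(chan = 64, conv = nn.Conv1d, activation = nn.SiLU)
    out = block(torch.randn(1, 64, 128))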
@@ -69,6 +69,7 @@ class DiscreteVAE(nn.Module):
         kernel_size = 4,
         use_transposed_convs = True,
         encoder_norm = False,
+        activation = 'relu',
         smooth_l1_loss = False,
         straight_through = False,
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@@ -94,6 +95,14 @@ class DiscreteVAE(nn.Module):
         if not use_transposed_convs:
             conv_transpose = functools.partial(UpsampledConv, conv)
 
+        if activation == 'relu':
+            act = nn.ReLU
+        elif activation == 'silu':
+            act = nn.SiLU
+        else:
+            assert NotImplementedError()
+
+
         enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
         dec_chans = list(reversed(enc_chans))
 
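This string-to-class lookup is the one place where new activations get registered: 'relu' and 'silu' map straight to nn.ReLU and nn.SiLU. A standalone sketch of equivalent selection logic, written with an explicit raise so an unknown name fails loudly (the helper name get_activation is illustrative and not part of this commit):

    from torch import nn

    def get_activation(name):
        # map config strings to activation module classes
        table = {'relu': nn.ReLU, 'silu': nn.SiLU}
        if name not in table:
            raise NotImplementedError(f'unsupported activation: {name}')
        return table[name]

    act = get_activation('silu')   # returns the nn.SiLU class, instantiated later as act()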
@@ -109,14 +118,14 @@ class DiscreteVAE(nn.Module):
 
         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
-            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
             if encoder_norm:
                 enc_layers.append(nn.GroupNorm(8, enc_out))
-            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
 
         for _ in range(num_resnet_blocks):
-            dec_layers.insert(0, ResBlock(dec_chans[1], conv))
-            enc_layers.append(ResBlock(enc_chans[-1], conv))
+            dec_layers.insert(0, ResBlock(dec_chans[1], conv, act))
+            enc_layers.append(ResBlock(enc_chans[-1], conv, act))
 
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
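From the caller's side the change is opt-in: constructing DiscreteVAE with activation = 'silu' swaps every ReLU in the encoder, decoder, and residual blocks for SiLU, while the default 'relu' keeps the previous behaviour. The swap works at the existing act() call sites because both activations are plain, zero-argument module classes, as this small check illustrates:

    import torch
    from torch import nn

    # Both classes construct with no arguments and apply elementwise, so they
    # are drop-in replacements for one another inside the Sequential blocks.
    for act_cls in (nn.ReLU, nn.SiLU):
        layer = act_cls()
        y = layer(torch.randn(2, 4))
        print(act_cls.__name__, tuple(y.shape))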