Support silu activation
parent 67bf7f5219
commit 08b33c8e3a
@@ -28,13 +28,13 @@ def eval_decorator(fn):
 
 
 class ResBlock(nn.Module):
-    def __init__(self, chan, conv):
+    def __init__(self, chan, conv, activation):
         super().__init__()
         self.net = nn.Sequential(
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 1)
         )
 
@@ -69,6 +69,7 @@ class DiscreteVAE(nn.Module):
         kernel_size = 4,
         use_transposed_convs = True,
         encoder_norm = False,
+        activation = 'relu',
         smooth_l1_loss = False,
         straight_through = False,
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@@ -94,6 +95,14 @@ class DiscreteVAE(nn.Module):
         if not use_transposed_convs:
             conv_transpose = functools.partial(UpsampledConv, conv)
 
+        if activation == 'relu':
+            act = nn.ReLU
+        elif activation == 'silu':
+            act = nn.SiLU
+        else:
+            raise NotImplementedError()
+
+
         enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
         dec_chans = list(reversed(enc_chans))
 
@@ -109,14 +118,14 @@ class DiscreteVAE(nn.Module):
 
         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
-            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
             if encoder_norm:
                 enc_layers.append(nn.GroupNorm(8, enc_out))
-            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
 
         for _ in range(num_resnet_blocks):
-            dec_layers.insert(0, ResBlock(dec_chans[1], conv))
-            enc_layers.append(ResBlock(enc_chans[-1], conv))
+            dec_layers.insert(0, ResBlock(dec_chans[1], conv, act))
+            enc_layers.append(ResBlock(enc_chans[-1], conv, act))
 
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
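Usage note (not part of the diff): with the patch applied, the new behaviour is selected by constructing DiscreteVAE with activation = 'silu'; the default 'relu' keeps the previous behaviour. Below is a minimal, self-contained sketch of the pattern the commit uses, assuming PyTorch >= 1.7 for nn.SiLU. get_activation and TinyResBlock are illustrative stand-ins, not code from this repository.

import torch
from torch import nn

def get_activation(name):
    # Mirrors the commit's string-to-class mapping: the activation *class*
    # is passed around and instantiated inside each block via activation().
    if name == 'relu':
        return nn.ReLU
    if name == 'silu':
        return nn.SiLU
    raise NotImplementedError(f"unsupported activation: {name}")

class TinyResBlock(nn.Module):
    # Illustrative stand-in for the patched ResBlock (conv kind fixed to Conv1d here).
    def __init__(self, chan, activation):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, 3, padding = 1),
            activation(),
            nn.Conv1d(chan, chan, 3, padding = 1),
            activation(),
            nn.Conv1d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

act = get_activation('silu')
block = TinyResBlock(64, act)
print(block(torch.randn(2, 64, 128)).shape)  # torch.Size([2, 64, 128])

Passing the activation as a class rather than an instance lets each block create as many fresh activation modules as it needs, which is why the diff stores nn.ReLU / nn.SiLU themselves in act.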