From 08b33c8e3a38b230d644680ef46d3ab20e4b1179 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 25 Aug 2021 09:03:14 -0600
Subject: [PATCH] Support silu activation

---
 codes/models/gpt_voice/lucidrains_dvae.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index 9b171619..9b89ffaf 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -28,13 +28,13 @@ def eval_decorator(fn):
 
 
 class ResBlock(nn.Module):
-    def __init__(self, chan, conv):
+    def __init__(self, chan, conv, activation):
         super().__init__()
         self.net = nn.Sequential(
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 3, padding = 1),
-            nn.ReLU(),
+            activation(),
             conv(chan, chan, 1)
         )
 
@@ -69,6 +69,7 @@ class DiscreteVAE(nn.Module):
         kernel_size = 4,
         use_transposed_convs = True,
         encoder_norm = False,
+        activation = 'relu',
         smooth_l1_loss = False,
         straight_through = False,
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@@ -94,6 +95,14 @@ class DiscreteVAE(nn.Module):
         if not use_transposed_convs:
             conv_transpose = functools.partial(UpsampledConv, conv)
 
+        if activation == 'relu':
+            act = nn.ReLU
+        elif activation == 'silu':
+            act = nn.SiLU
+        else:
+            assert NotImplementedError()
+
+
         enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
         dec_chans = list(reversed(enc_chans))
 
@@ -109,14 +118,14 @@ class DiscreteVAE(nn.Module):
 
         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
-            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
             if encoder_norm:
                 enc_layers.append(nn.GroupNorm(8, enc_out))
-            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
 
         for _ in range(num_resnet_blocks):
-            dec_layers.insert(0, ResBlock(dec_chans[1], conv))
-            enc_layers.append(ResBlock(enc_chans[-1], conv))
+            dec_layers.insert(0, ResBlock(dec_chans[1], conv, act))
+            enc_layers.append(ResBlock(enc_chans[-1], conv, act))
 
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
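
The patch threads a string-valued `activation` option from the `DiscreteVAE` constructor into `ResBlock` and the encoder/decoder stacks, mapping 'relu' to nn.ReLU and 'silu' to nn.SiLU. Note that the fallback branch, `assert NotImplementedError()`, never fires because an exception instance is truthy; actually rejecting an unknown name would require `raise`. Below is a minimal, self-contained sketch of the same name-to-class selection pattern; the `get_activation` helper and the small convolutional block are illustrative assumptions, not part of the repository.

import torch
from torch import nn


def get_activation(name: str):
    # Map an activation name to the corresponding nn.Module class,
    # mirroring the string -> class selection introduced by the patch.
    if name == 'relu':
        return nn.ReLU
    elif name == 'silu':
        return nn.SiLU
    # Unlike `assert NotImplementedError()`, raising actually rejects unknown names.
    raise NotImplementedError(f'unsupported activation: {name}')


act = get_activation('silu')  # returns the class; instantiate once per layer
block = nn.Sequential(
    nn.Conv1d(64, 64, 3, padding=1),
    act(),  # SiLU where nn.ReLU() used to be hard-coded
    nn.Conv1d(64, 64, 1),
)
print(block(torch.randn(2, 64, 100)).shape)  # torch.Size([2, 64, 100])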