diff --git a/codes/models/gpt_voice/unet_diffusion_tts9.py b/codes/models/gpt_voice/unet_diffusion_tts9.py
index 3d03aef3..ceee0e4e 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts9.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts9.py
@@ -34,6 +34,7 @@ class ResBlock(TimestepBlock):
         out_channels=None,
         dims=2,
         kernel_size=3,
+        efficient_config=True,
     ):
         super().__init__()
         self.channels = channels
@@ -41,11 +42,13 @@ class ResBlock(TimestepBlock):
         self.dropout = dropout
         self.out_channels = out_channels or channels
         padding = {1: 0, 3: 1, 5: 2}[kernel_size]
+        eff_kernel = 1 if efficient_config else 3
+        eff_padding = 0 if efficient_config else 1
 
         self.in_layers = nn.Sequential(
             normalization(channels),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 1, padding=0),
+            conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
         )
 
         self.emb_layers = nn.Sequential(
@@ -67,7 +70,7 @@ class ResBlock(TimestepBlock):
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
 
     def forward(self, x, emb):
         """
@@ -146,6 +149,7 @@ class DiffusionTts(nn.Module):
             kernel_size=3,
             scale_factor=2,
             time_embed_dim_multiplier=4,
+            efficient_convs=True,  # Uses kernels with width of 1 in several places rather than 3.
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for super-sampling.
@@ -178,6 +182,7 @@ class DiffusionTts(nn.Module):
         self.jit_enabled = jit_enabled
         self.jit_forward = None
         padding = 1 if kernel_size == 3 else 2
+        down_kernel = 1 if efficient_convs else 3
 
         time_embed_dim = model_channels * time_embed_dim_multiplier
         self.time_embed = nn.Sequential(
@@ -251,6 +256,7 @@ class DiffusionTts(nn.Module):
                         out_channels=int(mult * model_channels),
                         dims=dims,
                         kernel_size=kernel_size,
+                        efficient_config=efficient_convs
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -270,7 +276,7 @@ class DiffusionTts(nn.Module):
                 self.input_blocks.append(
                     TimestepEmbedSequential(
                         Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0
+                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
                         )
                     )
                 )
@@ -286,6 +292,7 @@ class DiffusionTts(nn.Module):
                 dropout,
                 dims=dims,
                 kernel_size=kernel_size,
+                efficient_config=efficient_convs,
             ),
             AttentionBlock(
                 ch,
@@ -298,6 +305,7 @@ class DiffusionTts(nn.Module):
                 dropout,
                 dims=dims,
                 kernel_size=kernel_size,
+                efficient_config=efficient_convs,
             ),
         )
         self._feature_size += ch
@@ -314,6 +322,7 @@ class DiffusionTts(nn.Module):
                         out_channels=int(model_channels * mult),
                         dims=dims,
                         kernel_size=kernel_size,
+                        efficient_config=efficient_convs,
                     )
                 ]
                 ch = int(model_channels * mult)
@@ -466,7 +475,8 @@ if __name__ == '__main__':
                          kernel_size=3,
                          scale_factor=2,
                          time_embed_dim_multiplier=4,
-                         super_sampling=False)
+                         super_sampling=False,
+                         efficient_convs=False)
     # Test with latent aligned conditioning
    o = model(clip, ts, aligned_latent, cond)
    # Test with sequence aligned conditioning
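
Note (not part of the diff): a minimal sketch of what the efficient_config flag changes. make_proj and the direct use of nn.Conv1d are hypothetical stand-ins for the repo's conv_nd helper; the point is only that both the 1-wide ("efficient") and 3-wide variants preserve the sequence length, so the flag can be toggled without changing any tensor shapes.

import torch
import torch.nn as nn

def make_proj(channels, out_channels, efficient_config=True):
    # Hypothetical helper mirroring the kernel/padding selection in the diff:
    # efficient_config -> pointwise (1-wide) conv with no padding,
    # otherwise a 3-wide conv with padding=1. Both keep the time dimension.
    eff_kernel = 1 if efficient_config else 3
    eff_padding = 0 if efficient_config else 1
    return nn.Conv1d(channels, out_channels, eff_kernel, padding=eff_padding)

x = torch.randn(2, 64, 100)              # (batch, channels, time)
for eff in (True, False):
    y = make_proj(64, 128, efficient_config=eff)(x)
    assert y.shape == (2, 128, 100)      # same output length either way

The trade-off: the 1-wide conv here has 64*128 + 128 weights versus 3*64*128 + 128 for the 3-wide one, roughly a 3x parameter/FLOP reduction, at the cost of no spatial mixing in those layers.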