diff --git a/codes/models/gpt_voice/unet_diffusion_tts9.py b/codes/models/gpt_voice/unet_diffusion_tts9.py
index f24fc5c5..b9520638 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts9.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts9.py
@@ -35,12 +35,14 @@ class ResBlock(TimestepBlock):
             dims=2,
             kernel_size=3,
             efficient_config=True,
+            use_scale_shift_norm=False,
     ):
         super().__init__()
         self.channels = channels
         self.emb_channels = emb_channels
         self.dropout = dropout
         self.out_channels = out_channels or channels
+        self.use_scale_shift_norm = use_scale_shift_norm
         padding = {1: 0, 3: 1, 5: 2}[kernel_size]
         eff_kernel = 1 if efficient_config else 3
         eff_padding = 0 if efficient_config else 1
@@ -55,7 +57,7 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
@@ -89,8 +91,14 @@ class ResBlock(TimestepBlock):
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
-        h = h + emb_out
-        h = self.out_layers(h)
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
         return self.skip_connection(x) + h
 
 class DiffusionTts(nn.Module):
@@ -150,6 +158,7 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             time_embed_dim_multiplier=4,
             efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
+            use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for super-sampling.
@@ -217,11 +226,11 @@ class DiffusionTts(nn.Module):
         self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
         )
         self.input_blocks = nn.ModuleList(
@@ -253,7 +262,8 @@ class DiffusionTts(nn.Module):
                         out_channels=int(mult * model_channels),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs
+                        efficient_config=efficient_convs,
+                        use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -290,6 +300,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
             AttentionBlock(
                 ch,
@@ -303,6 +314,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
         )
         self._feature_size += ch
@@ -320,6 +332,7 @@ class DiffusionTts(nn.Module):
                         dims=dims,
                         kernel_size=kernel_size,
                         efficient_config=efficient_convs,
+                        use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
                 ch = int(model_channels * mult)
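For context, the new `use_scale_shift_norm` branch applies FiLM-style conditioning (the "AdaGN" / scale-shift norm used in OpenAI's diffusion UNets): the timestep embedding is projected to `2 * out_channels`, split into a per-channel scale and shift, and used to modulate the normalized activations rather than being added to them. Below is a minimal standalone sketch of that computation; the module name, group count, and tensor shapes are illustrative assumptions, not part of the patched file:

```python
import torch
import torch.nn as nn


class ScaleShiftNormSketch(nn.Module):
    """Illustrative sketch of scale-shift (FiLM-style) timestep conditioning."""

    def __init__(self, channels, emb_channels):
        super().__init__()
        # Project the embedding to 2*channels so it can be split into scale and shift.
        self.emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_channels, 2 * channels))
        self.norm = nn.GroupNorm(32, channels)
        self.out = nn.Sequential(nn.SiLU(), nn.Conv1d(channels, channels, 3, padding=1))

    def forward(self, h, emb):
        # h: (batch, channels, time); emb: (batch, emb_channels)
        emb_out = self.emb_proj(emb)[..., None]        # (batch, 2*channels, 1)
        scale, shift = torch.chunk(emb_out, 2, dim=1)  # each (batch, channels, 1)
        # Modulate the normalized activations instead of adding the embedding to them.
        h = self.norm(h) * (1 + scale) + shift
        return self.out(h)


if __name__ == "__main__":
    block = ScaleShiftNormSketch(channels=64, emb_channels=256)
    y = block(torch.randn(2, 64, 100), torch.randn(2, 256))
    print(y.shape)  # torch.Size([2, 64, 100])
```

This is why the patch widens `emb_layers` to `2 * self.out_channels` when `use_scale_shift_norm` is set: the doubled projection is what `torch.chunk(emb_out, 2, dim=1)` splits into the scale and shift terms.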