add scale_shift_norm back to tts9

This commit is contained in:
James Betker 2022-03-12 20:42:13 -07:00
parent 9bbbe26012
commit 8f130e2b3f

View File

@@ -35,12 +35,14 @@ class ResBlock(TimestepBlock):
             dims=2,
             kernel_size=3,
             efficient_config=True,
+            use_scale_shift_norm=False,
     ):
         super().__init__()
         self.channels = channels
         self.emb_channels = emb_channels
         self.dropout = dropout
         self.out_channels = out_channels or channels
+        self.use_scale_shift_norm = use_scale_shift_norm
         padding = {1: 0, 3: 1, 5: 2}[kernel_size]
         eff_kernel = 1 if efficient_config else 3
         eff_padding = 0 if efficient_config else 1
@@ -55,7 +57,7 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
@@ -89,8 +91,14 @@ class ResBlock(TimestepBlock):
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
-        h = h + emb_out
-        h = self.out_layers(h)
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
         return self.skip_connection(x) + h


 class DiffusionTts(nn.Module):
@@ -150,6 +158,7 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             time_embed_dim_multiplier=4,
             efficient_convs=True,  # Uses kernels with width of 1 in several places rather than 3.
+            use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for super-sampling.
@@ -217,11 +226,11 @@ class DiffusionTts(nn.Module):
         self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
         )

         self.input_blocks = nn.ModuleList(
@@ -253,7 +262,8 @@ class DiffusionTts(nn.Module):
                         out_channels=int(mult * model_channels),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs
+                        efficient_config=efficient_convs,
+                        use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -290,6 +300,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
             AttentionBlock(
                 ch,
@@ -303,6 +314,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
         )
         self._feature_size += ch
@@ -320,6 +332,7 @@ class DiffusionTts(nn.Module):
                     dims=dims,
                     kernel_size=kernel_size,
                     efficient_config=efficient_convs,
+                    use_scale_shift_norm=use_scale_shift_norm,
                 )
             ]
             ch = int(model_channels * mult)