forked from mrq/DL-Art-School
add scale_shift_norm back to tts9
This commit is contained in:
parent 9bbbe26012
commit 8f130e2b3f
@@ -35,12 +35,14 @@ class ResBlock(TimestepBlock):
             dims=2,
             kernel_size=3,
             efficient_config=True,
+            use_scale_shift_norm=False,
     ):
         super().__init__()
         self.channels = channels
         self.emb_channels = emb_channels
         self.dropout = dropout
         self.out_channels = out_channels or channels
+        self.use_scale_shift_norm = use_scale_shift_norm
         padding = {1: 0, 3: 1, 5: 2}[kernel_size]
         eff_kernel = 1 if efficient_config else 3
         eff_padding = 0 if efficient_config else 1
@@ -55,7 +57,7 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
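Because the embedding projection now emits 2 * out_channels whenever use_scale_shift_norm is enabled, emb_layers weights saved under the old additive default no longer match shape-for-shape. Below is a hedged sketch of a shape check one might run before loading such a checkpoint; the helper name and the flat state_dict layout are assumptions, not part of this repo.

import torch

def report_shape_mismatches(model, checkpoint_path):
    # Compare every tensor in the saved state_dict against the freshly
    # constructed model; the emb_layers weights/biases are the ones expected
    # to differ when the scale-shift flag changes.
    saved = torch.load(checkpoint_path, map_location="cpu")
    current = model.state_dict()
    for key, tensor in saved.items():
        if key in current and current[key].shape != tensor.shape:
            print(f"{key}: checkpoint {tuple(tensor.shape)} vs model {tuple(current[key].shape)}")

Note that strict=False in load_state_dict only tolerates missing or unexpected keys; size mismatches still raise, so mismatched keys have to be dropped or the flag kept consistent with the checkpoint.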
@@ -89,8 +91,14 @@ class ResBlock(TimestepBlock):
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
-        h = h + emb_out
-        h = self.out_layers(h)
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
         return self.skip_connection(x) + h

 class DiffusionTts(nn.Module):
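Taken together, the two ResBlock hunks re-enable FiLM/AdaGN-style conditioning: the timestep embedding is projected to twice the channel count, chunked into a per-channel scale and shift, and applied to the normalized activations as norm(h) * (1 + scale) + shift rather than being added directly. Here is a minimal, self-contained sketch of that path; the class and layer names are illustrative, not the repo's ResBlock or its linear/normalization helpers.

import torch
import torch.nn as nn

class TinyScaleShiftBlock(nn.Module):
    def __init__(self, channels, emb_channels, use_scale_shift_norm=True):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm
        # With scale-shift norm on, the embedding MLP must emit 2*channels so
        # its output can be chunked into a per-channel scale and shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, 2 * channels if use_scale_shift_norm else channels),
        )
        self.norm = nn.GroupNorm(8, channels)
        self.rest = nn.Sequential(nn.SiLU(), nn.Conv1d(channels, channels, 3, padding=1))

    def forward(self, h, emb):
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]            # broadcast over the time axis
        if self.use_scale_shift_norm:
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = self.norm(h) * (1 + scale) + shift  # modulate normalized activations
            return self.rest(h)
        # Fallback: plain additive conditioning before the whole output stack.
        return self.rest(self.norm(h + emb_out))

# Usage: batch of 2, 32 channels, 100 frames of audio features.
block = TinyScaleShiftBlock(channels=32, emb_channels=64)
y = block(torch.randn(2, 32, 100), torch.randn(2, 64))
print(y.shape)  # torch.Size([2, 32, 100])

The (1 + scale) form keeps the modulation close to identity when the embedding projection starts near zero.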
@@ -150,6 +158,7 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             time_embed_dim_multiplier=4,
             efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
+            use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for super-sampling.
@@ -217,11 +226,11 @@ class DiffusionTts(nn.Module):
         self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
             AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1),
+            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
         )

         self.input_blocks = nn.ModuleList(
@@ -253,7 +262,8 @@ class DiffusionTts(nn.Module):
                         out_channels=int(mult * model_channels),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs
+                        efficient_config=efficient_convs,
+                        use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -290,6 +300,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
             AttentionBlock(
                 ch,
@@ -303,6 +314,7 @@ class DiffusionTts(nn.Module):
                 dims=dims,
                 kernel_size=kernel_size,
                 efficient_config=efficient_convs,
+                use_scale_shift_norm=use_scale_shift_norm,
             ),
         )
         self._feature_size += ch
@@ -320,6 +332,7 @@ class DiffusionTts(nn.Module):
                         dims=dims,
                         kernel_size=kernel_size,
                         efficient_config=efficient_convs,
+                        use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
                 ch = int(model_channels * mult)
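The remaining hunks simply thread the flag through every ResBlock constructed in DiffusionTts (conditioning integrator, input blocks, middle block, output blocks). A quick, hedged sanity check one could run after building the model to confirm no call site still falls back to the False default; the helper is hypothetical and not part of the repo.

import torch.nn as nn

def check_scale_shift(model: nn.Module, expected: bool = True):
    # Collect any submodule that exposes the flag but disagrees with the
    # expected setting; with this commit applied the list should be empty.
    offenders = [
        name for name, m in model.named_modules()
        if hasattr(m, "use_scale_shift_norm") and m.use_scale_shift_norm != expected
    ]
    if offenders:
        raise AssertionError(f"blocks without scale_shift_norm: {offenders}")

# e.g. check_scale_shift(DiffusionTts(...), expected=True)  # constructor args elided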