forked from mrq/DL-Art-School
add efficient config to tts9
This commit is contained in:
parent
896accb71f
commit
0523777ff7
|
@ -34,6 +34,7 @@ class ResBlock(TimestepBlock):
|
||||||
out_channels=None,
|
out_channels=None,
|
||||||
dims=2,
|
dims=2,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
|
efficient_config=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
@ -41,11 +42,13 @@ class ResBlock(TimestepBlock):
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
|
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
|
||||||
|
eff_kernel = 1 if efficient_config else 3
|
||||||
|
eff_padding = 0 if efficient_config else 1
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
normalization(channels),
|
normalization(channels),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
conv_nd(dims, channels, self.out_channels, 1, padding=0),
|
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.emb_layers = nn.Sequential(
|
self.emb_layers = nn.Sequential(
|
||||||
|
@ -67,7 +70,7 @@ class ResBlock(TimestepBlock):
|
||||||
if self.out_channels == channels:
|
if self.out_channels == channels:
|
||||||
self.skip_connection = nn.Identity()
|
self.skip_connection = nn.Identity()
|
||||||
else:
|
else:
|
||||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""
|
"""
|
||||||
|
@ -146,6 +149,7 @@ class DiffusionTts(nn.Module):
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
time_embed_dim_multiplier=4,
|
time_embed_dim_multiplier=4,
|
||||||
|
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
# Parameters for super-sampling.
|
# Parameters for super-sampling.
|
||||||
|
@ -178,6 +182,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.jit_enabled = jit_enabled
|
self.jit_enabled = jit_enabled
|
||||||
self.jit_forward = None
|
self.jit_forward = None
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
down_kernel = 1 if efficient_convs else 3
|
||||||
|
|
||||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
|
@ -251,6 +256,7 @@ class DiffusionTts(nn.Module):
|
||||||
out_channels=int(mult * model_channels),
|
out_channels=int(mult * model_channels),
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
|
efficient_config=efficient_convs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = int(mult * model_channels)
|
ch = int(mult * model_channels)
|
||||||
|
@ -270,7 +276,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.input_blocks.append(
|
self.input_blocks.append(
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
Downsample(
|
Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0
|
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -286,6 +292,7 @@ class DiffusionTts(nn.Module):
|
||||||
dropout,
|
dropout,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
|
efficient_config=efficient_convs,
|
||||||
),
|
),
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
ch,
|
ch,
|
||||||
|
@ -298,6 +305,7 @@ class DiffusionTts(nn.Module):
|
||||||
dropout,
|
dropout,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
|
efficient_config=efficient_convs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
@ -314,6 +322,7 @@ class DiffusionTts(nn.Module):
|
||||||
out_channels=int(model_channels * mult),
|
out_channels=int(model_channels * mult),
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
|
efficient_config=efficient_convs,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = int(model_channels * mult)
|
ch = int(model_channels * mult)
|
||||||
|
@ -466,7 +475,8 @@ if __name__ == '__main__':
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
time_embed_dim_multiplier=4,
|
time_embed_dim_multiplier=4,
|
||||||
super_sampling=False)
|
super_sampling=False,
|
||||||
|
efficient_convs=False)
|
||||||
# Test with latent aligned conditioning
|
# Test with latent aligned conditioning
|
||||||
o = model(clip, ts, aligned_latent, cond)
|
o = model(clip, ts, aligned_latent, cond)
|
||||||
# Test with sequence aligned conditioning
|
# Test with sequence aligned conditioning
|
||||||
|
|
Loading…
Reference in New Issue
Block a user