forked from mrq/DL-Art-School
gen3 waveform
This commit is contained in:
parent
b19b0a74da
commit
ff8b0533ac
|
@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
|
||||||
Downsample, Upsample, TimestepBlock
|
Downsample, Upsample, TimestepBlock
|
||||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint, print_network
|
||||||
|
|
||||||
|
|
||||||
def is_sequence(t):
|
def is_sequence(t):
|
||||||
|
@ -23,7 +23,7 @@ class ResBlock(TimestepBlock):
|
||||||
out_channels=None,
|
out_channels=None,
|
||||||
dims=2,
|
dims=2,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
efficient_config=True,
|
efficient_config=False,
|
||||||
use_scale_shift_norm=False,
|
use_scale_shift_norm=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -93,6 +93,8 @@ class ResBlock(TimestepBlock):
|
||||||
|
|
||||||
class StackedResidualBlock(TimestepBlock):
|
class StackedResidualBlock(TimestepBlock):
|
||||||
def __init__(self, channels, emb_channels, dropout):
|
def __init__(self, channels, emb_channels, dropout):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.emb_layers = nn.Sequential(
|
self.emb_layers = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(
|
linear(
|
||||||
|
@ -102,29 +104,30 @@ class StackedResidualBlock(TimestepBlock):
|
||||||
)
|
)
|
||||||
|
|
||||||
gc = channels // 4
|
gc = channels // 4
|
||||||
super().__init__()
|
|
||||||
self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
|
self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
out_channels = channels if i == 4 else gc
|
out_channels = channels if i == 4 else gc
|
||||||
self.add_module(
|
self.add_module(
|
||||||
f'conv{i + 1}',
|
f'conv{i + 1}',
|
||||||
nn.Conv2d(channels + i * gc, out_channels, 3, 1, 1))
|
nn.Conv1d(channels + i * gc, out_channels, 3, 1, 1))
|
||||||
self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=channels))
|
self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=out_channels))
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
zero_module(self.conv5)
|
zero_module(self.conv5)
|
||||||
|
self.drop = nn.Dropout(p=dropout)
|
||||||
|
|
||||||
def forawrd(self, x, emb):
|
def forward(self, x, emb):
|
||||||
return checkpoint(self.forward_, x, emb)
|
return checkpoint(self.forward_, x, emb)
|
||||||
|
|
||||||
def forward_(self, x, emb):
|
def forward_(self, x, emb):
|
||||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
emb_out = self.emb_layers(emb)
|
||||||
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
||||||
x0 = self.initial_norm(x) * (1 + scale) + shift
|
x0 = self.initial_norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
|
||||||
x1 = self.lrelu(self.gn1(self.conv1(x0)))
|
x1 = self.lrelu(self.gn1(self.conv1(x0)))
|
||||||
x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
|
x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
|
||||||
x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
|
x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
|
||||||
x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
|
x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
x5 = self.drop(x5)
|
||||||
|
|
||||||
return x5 + x
|
return x5 + x
|
||||||
|
|
||||||
|
@ -152,15 +155,14 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_channels,
|
model_channels=512,
|
||||||
in_channels=64,
|
in_channels=64,
|
||||||
in_mel_channels=256,
|
in_mel_channels=256,
|
||||||
conditioning_dim_factor=8,
|
conditioning_dim_factor=4,
|
||||||
conditioning_expansion=4,
|
|
||||||
out_channels=128, # mean and variance
|
out_channels=128, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
channel_mult= (1,1.5,2),
|
channel_mult= (1,1.5,2),
|
||||||
num_res_blocks=(1,1,1),
|
num_res_blocks=(1,1,0),
|
||||||
token_conditioning_resolutions=(1,4),
|
token_conditioning_resolutions=(1,4),
|
||||||
mid_resnet_depth=10,
|
mid_resnet_depth=10,
|
||||||
conv_resample=True,
|
conv_resample=True,
|
||||||
|
@ -168,9 +170,8 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
time_embed_dim_multiplier=4,
|
time_embed_dim_multiplier=1,
|
||||||
freeze_main_net=False,
|
freeze_main_net=False,
|
||||||
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
|
||||||
use_scale_shift_norm=True,
|
use_scale_shift_norm=True,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
|
@ -197,7 +198,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.freeze_main_net = freeze_main_net
|
self.freeze_main_net = freeze_main_net
|
||||||
self.in_mel_channels = in_mel_channels
|
self.in_mel_channels = in_mel_channels
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
down_kernel = 1 if efficient_convs else 3
|
|
||||||
|
|
||||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
|
@ -213,12 +213,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
# transformer network.
|
# transformer network.
|
||||||
self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
|
self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
|
||||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
||||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
|
||||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
|
||||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
|
||||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
|
||||||
)
|
|
||||||
self.conditioning_expansion = conditioning_expansion
|
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -249,7 +243,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
out_channels=int(mult * model_channels),
|
out_channels=int(mult * model_channels),
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
efficient_config=efficient_convs,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -262,7 +255,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.input_blocks.append(
|
self.input_blocks.append(
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
Downsample(
|
Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
|
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -286,7 +279,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
out_channels=int(model_channels * mult),
|
out_channels=int(model_channels * mult),
|
||||||
dims=dims,
|
dims=dims,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
efficient_config=efficient_convs,
|
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -360,11 +352,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
else:
|
else:
|
||||||
code_emb = self.mel_converter(aligned_conditioning)
|
code_emb = self.mel_converter(aligned_conditioning)
|
||||||
|
|
||||||
# Everything after this comment is timestep dependent.
|
|
||||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
|
||||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
|
||||||
|
|
||||||
first = True
|
|
||||||
time_emb = time_emb.float()
|
time_emb = time_emb.float()
|
||||||
h = x
|
h = x
|
||||||
for k, module in enumerate(self.input_blocks):
|
for k, module in enumerate(self.input_blocks):
|
||||||
|
@ -374,7 +361,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
else:
|
else:
|
||||||
h = module(h, time_emb)
|
h = module(h, time_emb)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
first = False
|
|
||||||
h = self.middle_block(h, time_emb)
|
h = self.middle_block(h, time_emb)
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
@ -398,18 +384,11 @@ def register_unet_diffusion_waveform_gen3(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 1, 32868)
|
clip = torch.randn(2, 64, 880)
|
||||||
aligned_sequence = torch.randn(2,120,220)
|
aligned_sequence = torch.randn(2,256,220)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = DiffusionWaveformGen(128,
|
model = DiffusionWaveformGen()
|
||||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
|
||||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
|
||||||
token_conditioning_resolutions=[1,4,16,64],
|
|
||||||
kernel_size=3,
|
|
||||||
scale_factor=2,
|
|
||||||
time_embed_dim_multiplier=4,
|
|
||||||
super_sampling=False,
|
|
||||||
efficient_convs=False)
|
|
||||||
# Test with sequence aligned conditioning
|
# Test with sequence aligned conditioning
|
||||||
o = model(clip, ts, aligned_sequence)
|
o = model(clip, ts, aligned_sequence)
|
||||||
|
print_network(model)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user