forked from mrq/DL-Art-School

gen3 waveform

parent b19b0a74da
commit ff8b0533ac
@@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
     Downsample, Upsample, TimestepBlock
 from scripts.audio.gen.use_diffuse_tts import ceil_multiple
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, print_network
 
 
 def is_sequence(t):
@@ -23,7 +23,7 @@ class ResBlock(TimestepBlock):
             out_channels=None,
             dims=2,
             kernel_size=3,
-            efficient_config=True,
+            efficient_config=False,
             use_scale_shift_norm=False,
     ):
         super().__init__()
@@ -93,6 +93,8 @@ class ResBlock(TimestepBlock):
 
 class StackedResidualBlock(TimestepBlock):
     def __init__(self, channels, emb_channels, dropout):
+        super().__init__()
+
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
             linear(
@@ -102,29 +104,30 @@ class StackedResidualBlock(TimestepBlock):
         )
 
         gc = channels // 4
-        super().__init__()
+        self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
         for i in range(5):
             out_channels = channels if i == 4 else gc
             self.add_module(
                 f'conv{i + 1}',
-                nn.Conv2d(channels + i * gc, out_channels, 3, 1, 1))
-            self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=channels))
+                nn.Conv1d(channels + i * gc, out_channels, 3, 1, 1))
+            self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=out_channels))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         zero_module(self.conv5)
         self.drop = nn.Dropout(p=dropout)
 
-    def forawrd(self, x, emb):
+    def forward(self, x, emb):
         return checkpoint(self.forward_, x, emb)
 
     def forward_(self, x, emb):
-        emb_out = self.emb_layers(emb).type(h.dtype)
+        emb_out = self.emb_layers(emb)
         scale, shift = torch.chunk(emb_out, 2, dim=1)
-        x0 = self.initial_norm(x) * (1 + scale) + shift
+        x0 = self.initial_norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
         x1 = self.lrelu(self.gn1(self.conv1(x0)))
         x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
         x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
         x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
         x5 = self.drop(x5)
 
         return x5 + x
+
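Note: after this hunk, StackedResidualBlock is a 1D dense-residual block (five densely connected convs, RDB-style) gated up front by a FiLM-style (scale, shift) projection of the timestep embedding. Below is a minimal standalone sketch of that pattern, assuming plain nn.Linear in place of the repo's linear() helper and omitting the zero_module() init on conv5 and the checkpoint() wrapper; it assumes channels divisible by 32 so every GroupNorm(8) divides evenly (with the new model_channels=512 default, gc is 128).

import torch
import torch.nn as nn

class StackedResidualBlockSketch(nn.Module):
    def __init__(self, channels, emb_channels, dropout=0.0):
        super().__init__()
        gc = channels // 4  # growth channels contributed by each dense stage
        # Timestep embedding -> per-channel (scale, shift) for the input norm.
        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(emb_channels, 2 * channels))
        self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
        for i in range(5):
            out_channels = channels if i == 4 else gc
            self.add_module(f'conv{i + 1}', nn.Conv1d(channels + i * gc, out_channels, 3, 1, 1))
            # gn5 is created but, as in the commit, never applied to conv5's output.
            self.add_module(f'gn{i + 1}', nn.GroupNorm(num_groups=8, num_channels=out_channels))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, emb):
        # x: (B, channels, T); emb: (B, emb_channels)
        scale, shift = torch.chunk(self.emb_layers(emb), 2, dim=1)
        x0 = self.initial_norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
        # Each conv sees the original input concatenated with all prior stage outputs.
        x1 = self.lrelu(self.gn1(self.conv1(x0)))
        x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
        x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
        x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
        x5 = self.drop(self.conv5(torch.cat((x, x1, x2, x3, x4), 1)))
        return x5 + x

Because conv5 maps back to `channels`, the block is shape-preserving and the `x5 + x` residual is well-formed; zero-initializing conv5 (as the repo does) makes the block start out as an identity.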
@@ -152,15 +155,14 @@ class DiffusionWaveformGen(nn.Module):
 
     def __init__(
             self,
-            model_channels,
+            model_channels=512,
             in_channels=64,
             in_mel_channels=256,
-            conditioning_dim_factor=8,
-            conditioning_expansion=4,
+            conditioning_dim_factor=4,
             out_channels=128, # mean and variance
             dropout=0,
             channel_mult= (1,1.5,2),
-            num_res_blocks=(1,1,1),
+            num_res_blocks=(1,1,0),
             token_conditioning_resolutions=(1,4),
             mid_resnet_depth=10,
             conv_resample=True,
@@ -168,9 +170,8 @@ class DiffusionWaveformGen(nn.Module):
             use_fp16=False,
             kernel_size=3,
             scale_factor=2,
-            time_embed_dim_multiplier=4,
+            time_embed_dim_multiplier=1,
             freeze_main_net=False,
-            efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
             use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
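Note on unconditioned_percentage: per the inline comment above, a slice of each batch is trained without conditioning, in the spirit of classifier-free guidance. A minimal sketch of that mechanism follows (illustrative names, not the repo's exact forward logic), with the learned unconditioned_embedding of shape (1, C, 1) standing in for the dropped conditioning:

import torch

def maybe_drop_conditioning(code_emb, unconditioned_embedding, unconditioned_percentage):
    # code_emb: (B, C, T) conditioning. With probability `unconditioned_percentage`
    # per sample, swap it for the learned unconditioned embedding, broadcast over time.
    if unconditioned_percentage > 0:
        b, _, t = code_emb.shape
        mask = torch.rand(b, 1, 1, device=code_emb.device) < unconditioned_percentage
        code_emb = torch.where(mask, unconditioned_embedding.expand(b, -1, t), code_emb)
    return code_emb

At inference time the same unconditioned embedding can be fed as the "empty" condition to steer sampling, which is the usual payoff of training this way.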
@@ -197,7 +198,6 @@ class DiffusionWaveformGen(nn.Module):
         self.freeze_main_net = freeze_main_net
         self.in_mel_channels = in_mel_channels
         padding = 1 if kernel_size == 3 else 2
-        down_kernel = 1 if efficient_convs else 3
 
         time_embed_dim = model_channels * time_embed_dim_multiplier
         self.time_embed = nn.Sequential(
@@ -213,12 +213,6 @@ class DiffusionWaveformGen(nn.Module):
         # transformer network.
         self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
-        self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-        )
-        self.conditioning_expansion = conditioning_expansion
 
         self.input_blocks = nn.ModuleList(
             [
|
@ -249,7 +243,6 @@ class DiffusionWaveformGen(nn.Module):
|
|||
out_channels=int(mult * model_channels),
|
||||
dims=dims,
|
||||
kernel_size=kernel_size,
|
||||
efficient_config=efficient_convs,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
|
@@ -262,7 +255,7 @@ class DiffusionWaveformGen(nn.Module):
                 self.input_blocks.append(
                     TimestepEmbedSequential(
                         Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
+                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
                         )
                     )
                 )
@@ -286,7 +279,6 @@ class DiffusionWaveformGen(nn.Module):
                         out_channels=int(model_channels * mult),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -360,11 +352,6 @@ class DiffusionWaveformGen(nn.Module):
         else:
             code_emb = self.mel_converter(aligned_conditioning)
 
-        # Everything after this comment is timestep dependent.
-        code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
-        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
-
-        first = True
         time_emb = time_emb.float()
         h = x
         for k, module in enumerate(self.input_blocks):
@@ -374,7 +361,6 @@ class DiffusionWaveformGen(nn.Module):
             else:
                 h = module(h, time_emb)
             hs.append(h)
-            first = False
         h = self.middle_block(h, time_emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
@@ -398,18 +384,11 @@ def register_unet_diffusion_waveform_gen3(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 32868)
-    aligned_sequence = torch.randn(2,120,220)
+    clip = torch.randn(2, 64, 880)
+    aligned_sequence = torch.randn(2,256,220)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionWaveformGen(128,
-                                 channel_mult=[1,1.5,2, 3, 4, 6, 8],
-                                 num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
-                                 token_conditioning_resolutions=[1,4,16,64],
-                                 kernel_size=3,
-                                 scale_factor=2,
-                                 time_embed_dim_multiplier=4,
-                                 super_sampling=False,
-                                 efficient_convs=False)
+    model = DiffusionWaveformGen()
     # Test with sequence aligned conditioning
     o = model(clip, ts, aligned_sequence)
+    print_network(model)
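For reference, print_network (newly imported at the top of this diff) is used here to report model size after the smoke test. The exact utils.util implementation may differ; a plain-PyTorch stand-in with the same intent might look like this:

import torch.nn as nn

def print_network_sketch(net: nn.Module):
    # One-line summary: class name plus trainable parameter count.
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{net.__class__.__name__}: {n_params / 1e6:.2f}M trainable parameters')

With the new defaults (model_channels=512, in_channels=64), the updated smoke test feeds a (2, 64, 880) clip against a (2, 256, 220) mel sequence, i.e. the conditioning covers a quarter of the clip's temporal length.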