gen3 waveform

pull/9/head
James Betker 2022-06-19 19:23:48 +07:00
parent b19b0a74da
commit ff8b0533ac
1 changed file with 20 additions and 41 deletions

@@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
from utils.util import checkpoint, print_network
def is_sequence(t):
@@ -23,7 +23,7 @@ class ResBlock(TimestepBlock):
out_channels=None,
dims=2,
kernel_size=3,
efficient_config=True,
efficient_config=False,
use_scale_shift_norm=False,
):
super().__init__()
@@ -93,6 +93,8 @@ class ResBlock(TimestepBlock):
class StackedResidualBlock(TimestepBlock):
def __init__(self, channels, emb_channels, dropout):
super().__init__()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
@@ -102,29 +104,30 @@ class StackedResidualBlock(TimestepBlock):
)
gc = channels // 4
super().__init__()
self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
for i in range(5):
out_channels = channels if i == 4 else gc
self.add_module(
f'conv{i + 1}',
nn.Conv2d(channels + i * gc, out_channels, 3, 1, 1))
self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=channels))
nn.Conv1d(channels + i * gc, out_channels, 3, 1, 1))
self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=out_channels))
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
zero_module(self.conv5)
self.drop = nn.Dropout(p=dropout)
def forawrd(self, x, emb):
def forward(self, x, emb):
return checkpoint(self.forward_, x, emb)
def forward_(self, x, emb):
emb_out = self.emb_layers(emb).type(h.dtype)
emb_out = self.emb_layers(emb)
scale, shift = torch.chunk(emb_out, 2, dim=1)
x0 = self.initial_norm(x) * (1 + scale) + shift
x0 = self.initial_norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
x1 = self.lrelu(self.gn1(self.conv1(x0)))
x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
x5 = self.drop(x5)
return x5 + x
@@ -152,15 +155,14 @@ class DiffusionWaveformGen(nn.Module):
def __init__(
self,
model_channels,
model_channels=512,
in_channels=64,
in_mel_channels=256,
conditioning_dim_factor=8,
conditioning_expansion=4,
conditioning_dim_factor=4,
out_channels=128, # mean and variance
dropout=0,
channel_mult= (1,1.5,2),
num_res_blocks=(1,1,1),
num_res_blocks=(1,1,0),
token_conditioning_resolutions=(1,4),
mid_resnet_depth=10,
conv_resample=True,
@@ -168,9 +170,8 @@ class DiffusionWaveformGen(nn.Module):
use_fp16=False,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
time_embed_dim_multiplier=1,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
use_scale_shift_norm=True,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
@@ -197,7 +198,6 @@ class DiffusionWaveformGen(nn.Module):
self.freeze_main_net = freeze_main_net
self.in_mel_channels = in_mel_channels
padding = 1 if kernel_size == 3 else 2
down_kernel = 1 if efficient_convs else 3
time_embed_dim = model_channels * time_embed_dim_multiplier
self.time_embed = nn.Sequential(
@@ -213,12 +213,6 @@ class DiffusionWaveformGen(nn.Module):
# transformer network.
self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
)
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList(
[
@@ -249,7 +243,6 @@ class DiffusionWaveformGen(nn.Module):
out_channels=int(mult * model_channels),
dims=dims,
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
)
]
@@ -262,7 +255,7 @@ class DiffusionWaveformGen(nn.Module):
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
)
)
)
@@ -286,7 +279,6 @@ class DiffusionWaveformGen(nn.Module):
out_channels=int(model_channels * mult),
dims=dims,
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
)
]
@@ -360,11 +352,6 @@ class DiffusionWaveformGen(nn.Module):
else:
code_emb = self.mel_converter(aligned_conditioning)
# Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
first = True
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
@@ -374,7 +361,6 @@ class DiffusionWaveformGen(nn.Module):
else:
h = module(h, time_emb)
hs.append(h)
first = False
h = self.middle_block(h, time_emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
@@ -398,18 +384,11 @@ def register_unet_diffusion_waveform_gen3(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_sequence = torch.randn(2,120,220)
clip = torch.randn(2, 64, 880)
aligned_sequence = torch.randn(2,256,220)
ts = torch.LongTensor([600, 600])
model = DiffusionWaveformGen(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64],
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
super_sampling=False,
efficient_convs=False)
model = DiffusionWaveformGen()
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence)
print_network(model)