Which is basically a autoencoder with a giant diffusion appendage attached
This commit is contained in:
James Betker 2022-04-20 21:37:34 -06:00
parent b1c2c48720
commit 9df85c902e

View File

@ -144,25 +144,41 @@ class ResBlockSimple(nn.Module):
return self.skip_connection(x) + h
class StructuralProcessor(nn.Module):
class AudioVAE(nn.Module):
def __init__(self, channels, dropout):
super().__init__()
# 256,128,64,32,16,8,4,2,1
level_resblocks = [3, 3, 2, 2, 2,1,1,1]
level_ch_div = [1, 1, 2, 4, 4,8,8,16]
# 1, 4, 16, 64, 256
level_resblocks = [1, 1, 2, 2, 2]
level_ch_mult = [1, 2, 4, 6, 8]
levels = []
lastdiv = 1
for resblks, chdiv in zip(level_resblocks, level_ch_div):
levels.append(nn.Sequential(*([nn.Conv1d(channels//lastdiv, channels//chdiv, 1)] +
[ResBlockSimple(channels//chdiv, dropout) for _ in range(resblks)])))
for i, (resblks, chdiv) in enumerate(zip(level_resblocks, level_ch_mult)):
blocks = [ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)]
if i != len(level_ch_mult)-1:
blocks.append(nn.Conv1d(channels*chdiv, channels*level_ch_mult[i+1], kernel_size=5, padding=2, stride=4))
levels.append(nn.Sequential(*blocks))
self.down_levels = nn.ModuleList(levels)
levels = []
lastdiv = None
for resblks, chdiv in reversed(list(zip(level_resblocks, level_ch_mult))):
if lastdiv is not None:
blocks = [nn.Conv1d(channels*lastdiv, channels*chdiv, kernel_size=5, padding=2)]
else:
blocks = []
blocks.extend([ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)])
levels.append(nn.Sequential(*blocks))
lastdiv = chdiv
self.levels = nn.ModuleList(levels)
self.up_levels = nn.ModuleList(levels)
def forward(self, x):
h = x
for level in self.levels:
for level in self.down_levels:
h = level(h)
h = F.interpolate(h, scale_factor=2, mode='linear')
for k, level in enumerate(self.up_levels):
h = level(h)
if k != len(self.up_levels)-1:
h = F.interpolate(h, scale_factor=4, mode='linear')
return h
@ -178,20 +194,10 @@ class DiffusionTts(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
@ -202,8 +208,6 @@ class DiffusionTts(nn.Module):
self,
model_channels,
in_channels=1,
in_mel_channels=120,
conditioning_dim_factor=8,
out_channels=2, # mean and variance
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
@ -211,13 +215,9 @@ class DiffusionTts(nn.Module):
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
attention_resolutions=(512,1024,2048),
conv_resample=True,
dims=1,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
@ -229,24 +229,16 @@ class DiffusionTts(nn.Module):
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.dims = dims
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.alignment_size = 2 ** (len(channel_mult)+1)
self.in_mel_channels = in_mel_channels
self.alignment_size = max(2 ** (len(channel_mult)+1), 256)
padding = 1 if kernel_size == 3 else 2
down_kernel = 1 if efficient_convs else 3
@ -257,18 +249,17 @@ class DiffusionTts(nn.Module):
linear(time_embed_dim, time_embed_dim),
)
conditioning_dim = model_channels * conditioning_dim_factor
self.structural_cond_input = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_mel_channels,1))
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.structural_processor = StructuralProcessor(conditioning_dim, dropout)
self.surrogate_head = nn.Conv1d(conditioning_dim//16, in_channels, 1)
self.structural_cond_input = nn.Conv1d(in_channels, model_channels, kernel_size=5, padding=2)
self.aligned_latent_padding_embedding = nn.Parameter(torch.zeros(1,in_channels,1))
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
self.structural_processor = AudioVAE(model_channels, dropout)
self.surrogate_head = nn.Conv1d(model_channels, in_channels, 1)
self.input_block = conv_nd(dims, in_channels, model_channels//2, kernel_size, padding=padding)
self.input_block = conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, model_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, model_channels*2, model_channels, 1)
)
]
)
@ -292,14 +283,6 @@ class DiffusionTts(nn.Module):
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
@ -327,20 +310,6 @@ class DiffusionTts(nn.Module):
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
@ -361,14 +330,6 @@ class DiffusionTts(nn.Module):
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
)
)
if level and i == num_blocks:
out_ch = ch
layers.append(
@ -403,9 +364,6 @@ class DiffusionTts(nn.Module):
}
return groups
def is_latent(self, t):
return t.shape[1] != self.in_mel_channels
def fix_alignment(self, x, aligned_conditioning):
"""
The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
@ -415,36 +373,26 @@ class DiffusionTts(nn.Module):
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
# Also fix aligned_latent, which is aligned to x.
if self.is_latent(aligned_conditioning):
aligned_conditioning = torch.cat([aligned_conditioning,
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
else:
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
def forward(self, x, timesteps, conditioning, conditioning_free=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning: should just be the truth value. produces a latent through an autoencoder, then uses diffusion to decode that latent.
at inference, only the latent is passed in.
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
# Shuffle aligned_latent to BxCxS format
if self.is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
orig_x_shape = x.shape[-1]
x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
x, aligned_conditioning = self.fix_alignment(x, conditioning)
with autocast(x.device.type, enabled=self.enable_fp16):
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
if conditioning_free:
@ -456,10 +404,12 @@ class DiffusionTts(nn.Module):
code_emb = F.interpolate(code_emb, size=(x.shape[-1],), mode='linear')
surrogate = self.surrogate_head(code_emb)
# Everything after this comment is timestep dependent.
x = self.input_block(x)
x = torch.cat([x, code_emb], dim=1)
# Everything after this comment is timestep dependent.
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
@ -493,13 +443,11 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_sequence = torch.randn(2,120,128)
aligned_sequence = torch.randn(2,1,32868)
ts = torch.LongTensor([600, 600])
model = DiffusionTts(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
attention_resolutions=[],
num_heads=8,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,