forked from mrq/DL-Art-School
New gen2
Which is basically an autoencoder with a giant diffusion appendage attached
This commit is contained in:
parent b1c2c48720 · commit 9df85c902e
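
The commit message above describes the gen2 scheme: an autoencoder squeezes the ground-truth waveform into a latent, a small surrogate head reconstructs the waveform from that latent as an auxiliary target, and the diffusion UNet is conditioned on the latent. A toy, self-contained sketch of that idea follows (none of these names are the commit's own classes; the channel counts and lengths are assumptions chosen for illustration):

# Toy illustration of "autoencoder with a diffusion appendage" conditioning (not part of this commit).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyWaveAutoencoder(nn.Module):
    def __init__(self, ch=32, factor=4):
        super().__init__()
        self.down = nn.Conv1d(1, ch, kernel_size=5, stride=factor, padding=2)  # compress in time
        self.up = nn.Conv1d(ch, ch, kernel_size=5, padding=2)
        self.factor = factor

    def forward(self, wav):
        latent = self.down(wav)                                   # bottlenecked code
        h = F.interpolate(latent, scale_factor=self.factor, mode='linear')
        return self.up(h), latent

ae = ToyWaveAutoencoder()
surrogate_head = nn.Conv1d(32, 1, 1)               # auxiliary reconstruction, analogous to self.surrogate_head
truth = torch.randn(2, 1, 1024)                    # ground-truth waveform used as conditioning during training
noisy = torch.randn(2, 1, 1024)                    # the diffusion input x
code_emb, latent = ae(truth)
aux_recon = surrogate_head(code_emb)               # trained to reconstruct `truth` from the latent
denoiser_in = torch.cat([noisy, code_emb], dim=1)  # the UNet sees the noisy audio plus the autoencoded conditioning
print(denoiser_in.shape, latent.shape, aux_recon.shape)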
@@ -144,25 +144,41 @@ class ResBlockSimple(nn.Module):
         return self.skip_connection(x) + h


-class StructuralProcessor(nn.Module):
+class AudioVAE(nn.Module):
     def __init__(self, channels, dropout):
         super().__init__()
-        # 256,128,64,32,16,8,4,2,1
-        level_resblocks = [3, 3, 2, 2, 2,1,1,1]
-        level_ch_div = [1, 1, 2, 4, 4,8,8,16]
+        # 1, 4, 16, 64, 256
+        level_resblocks = [1, 1, 2, 2, 2]
+        level_ch_mult = [1, 2, 4, 6, 8]
         levels = []
-        lastdiv = 1
-        for resblks, chdiv in zip(level_resblocks, level_ch_div):
-            levels.append(nn.Sequential(*([nn.Conv1d(channels//lastdiv, channels//chdiv, 1)] +
-                                          [ResBlockSimple(channels//chdiv, dropout) for _ in range(resblks)])))
+        for i, (resblks, chdiv) in enumerate(zip(level_resblocks, level_ch_mult)):
+            blocks = [ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)]
+            if i != len(level_ch_mult)-1:
+                blocks.append(nn.Conv1d(channels*chdiv, channels*level_ch_mult[i+1], kernel_size=5, padding=2, stride=4))
+            levels.append(nn.Sequential(*blocks))
+        self.down_levels = nn.ModuleList(levels)
+
+        levels = []
+        lastdiv = None
+        for resblks, chdiv in reversed(list(zip(level_resblocks, level_ch_mult))):
+            if lastdiv is not None:
+                blocks = [nn.Conv1d(channels*lastdiv, channels*chdiv, kernel_size=5, padding=2)]
+            else:
+                blocks = []
+            blocks.extend([ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)])
+            levels.append(nn.Sequential(*blocks))
+            lastdiv = chdiv
-        self.levels = nn.ModuleList(levels)
+        self.up_levels = nn.ModuleList(levels)

     def forward(self, x):
         h = x
-        for level in self.levels:
+        for level in self.down_levels:
             h = level(h)
-            h = F.interpolate(h, scale_factor=2, mode='linear')
+
+        for k, level in enumerate(self.up_levels):
+            h = level(h)
+            if k != len(self.up_levels)-1:
+                h = F.interpolate(h, scale_factor=4, mode='linear')
         return h
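
As a sanity check on the new factors (a sketch, assuming the AudioVAE and ResBlockSimple definitions above are in scope and that ResBlockSimple with kernel_size=5 preserves length): the four stride-4 convolutions in the down path give a 4**4 = 256x temporal bottleneck, matching the "# 1, 4, 16, 64, 256" comment, and the four scale_factor=4 interpolations in the up path restore the original length. The batch size, channel count and clip length below are assumptions:

# Sketch only: verify the 256x bottleneck of the new AudioVAE.
import torch
vae = AudioVAE(channels=128, dropout=0.1)
x = torch.randn(2, 128, 1024)   # (batch, model_channels, time); time chosen as a multiple of 256
h = x
for level in vae.down_levels:
    h = level(h)
print(h.shape)                   # expect (2, 1024, 4): channels*8 wide, 1024/256 long
y = vae(x)
print(y.shape)                   # expect (2, 128, 1024): back to the input width and length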
@@ -178,20 +194,10 @@ class DiffusionTts(nn.Module):
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
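
For a concrete read on channel_mult (editor's arithmetic, not part of the commit, using the hyperparameters from the test harness at the bottom of this diff): each level's width is ch = int(mult * model_channels), the same expression used in the encoder and decoder builders further down.

# Per-level channel widths implied by model_channels=128 and the test-harness channel_mult.
model_channels = 128
channel_mult = [1, 1.5, 2, 3, 4, 6, 8]
print([int(m * model_channels) for m in channel_mult])   # [128, 192, 256, 384, 512, 768, 1024]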
@@ -202,8 +208,6 @@ class DiffusionTts(nn.Module):
            self,
            model_channels,
            in_channels=1,
-           in_mel_channels=120,
-           conditioning_dim_factor=8,
            out_channels=2,  # mean and variance
            dropout=0,
            # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
@@ -211,13 +215,9 @@ class DiffusionTts(nn.Module):
            num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
            # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
            # attn:         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
            attention_resolutions=(512,1024,2048),
            conv_resample=True,
            dims=1,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            kernel_size=3,
            scale_factor=2,
            time_embed_dim_multiplier=4,
@@ -229,24 +229,16 @@ class DiffusionTts(nn.Module):
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.dims = dims
        self.unconditioned_percentage = unconditioned_percentage
        self.enable_fp16 = use_fp16
-       self.alignment_size = 2 ** (len(channel_mult)+1)
-       self.in_mel_channels = in_mel_channels
+       self.alignment_size = max(2 ** (len(channel_mult)+1), 256)
        padding = 1 if kernel_size == 3 else 2
        down_kernel = 1 if efficient_convs else 3
@@ -257,18 +249,17 @@ class DiffusionTts(nn.Module):
            linear(time_embed_dim, time_embed_dim),
        )

-       conditioning_dim = model_channels * conditioning_dim_factor
-       self.structural_cond_input = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
-       self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_mel_channels,1))
-       self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
-       self.structural_processor = StructuralProcessor(conditioning_dim, dropout)
-       self.surrogate_head = nn.Conv1d(conditioning_dim//16, in_channels, 1)
+       self.structural_cond_input = nn.Conv1d(in_channels, model_channels, kernel_size=5, padding=2)
+       self.aligned_latent_padding_embedding = nn.Parameter(torch.zeros(1,in_channels,1))
+       self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
+       self.structural_processor = AudioVAE(model_channels, dropout)
+       self.surrogate_head = nn.Conv1d(model_channels, in_channels, 1)

-       self.input_block = conv_nd(dims, in_channels, model_channels//2, kernel_size, padding=padding)
+       self.input_block = conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
-                   conv_nd(dims, model_channels, model_channels, kernel_size, padding=padding)
+                   conv_nd(dims, model_channels*2, model_channels, 1)
                )
            ]
        )
@@ -292,14 +283,6 @@ class DiffusionTts(nn.Module):
                    )
                ]
                ch = int(mult * model_channels)
-               if ds in attention_resolutions:
-                   layers.append(
-                       AttentionBlock(
-                           ch,
-                           num_heads=num_heads,
-                           num_head_channels=num_head_channels,
-                       )
-                   )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
@@ -327,20 +310,6 @@ class DiffusionTts(nn.Module):
                efficient_config=efficient_convs,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
-           AttentionBlock(
-               ch,
-               num_heads=num_heads,
-               num_head_channels=num_head_channels,
-           ),
-           ResBlock(
-               ch,
-               time_embed_dim,
-               dropout,
-               dims=dims,
-               kernel_size=kernel_size,
-               efficient_config=efficient_convs,
-               use_scale_shift_norm=use_scale_shift_norm,
-           ),
        )
        self._feature_size += ch

@@ -361,14 +330,6 @@ class DiffusionTts(nn.Module):
                    )
                ]
                ch = int(model_channels * mult)
-               if ds in attention_resolutions:
-                   layers.append(
-                       AttentionBlock(
-                           ch,
-                           num_heads=num_heads_upsample,
-                           num_head_channels=num_head_channels,
-                       )
-                   )
                if level and i == num_blocks:
                    out_ch = ch
                    layers.append(
@@ -403,9 +364,6 @@ class DiffusionTts(nn.Module):
        }
        return groups

-   def is_latent(self, t):
-       return t.shape[1] != self.in_mel_channels
-
    def fix_alignment(self, x, aligned_conditioning):
        """
        The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
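
A quick worked example of that alignment constraint (editor's arithmetic, not part of the commit; it assumes cm, computed just above the next hunk, is the input length rounded up to a multiple of self.alignment_size, which the visible padding code suggests): with the 7-entry channel_mult from the test harness, alignment_size = max(2**(7+1), 256) = 256, so the 32868-sample test clip is right-padded to 33024 samples and the conditioning is padded by the same fraction pc.

# Assumed padding arithmetic for fix_alignment with the test-harness settings.
import math
alignment_size = max(2 ** (len([1, 1.5, 2, 3, 4, 6, 8]) + 1), 256)   # 256
cm = math.ceil(32868 / alignment_size) * alignment_size              # 33024
pc = (cm - 32868) / 32868                                             # ~0.0047, fraction used to pad the conditioning
print(alignment_size, cm, round(pc, 4))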
@@ -415,36 +373,26 @@ class DiffusionTts(nn.Module):
        if cm != 0:
            pc = (cm-x.shape[-1])/x.shape[-1]
            x = F.pad(x, (0,cm-x.shape[-1]))
            # Also fix aligned_latent, which is aligned to x.
-           if self.is_latent(aligned_conditioning):
-               aligned_conditioning = torch.cat([aligned_conditioning,
-                                                 self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
-           else:
-               aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
+           aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
        return x, aligned_conditioning

-   def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
+   def forward(self, x, timesteps, conditioning, conditioning_free=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
-       :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
+       :param conditioning: should just be the truth value. produces a latent through an autoencoder, then uses diffusion to decode that latent.
+                            at inference, only the latent is passed in.
        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
        :return: an [N x C x ...] Tensor of outputs.
        """

-       # Shuffle aligned_latent to BxCxS format
-       if self.is_latent(aligned_conditioning):
-           aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
-
        # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
        orig_x_shape = x.shape[-1]
-       x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
+       x, aligned_conditioning = self.fix_alignment(x, conditioning)

        with autocast(x.device.type, enabled=self.enable_fp16):
-           hs = []
-           time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
            # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
            if conditioning_free:
@@ -456,10 +404,12 @@ class DiffusionTts(nn.Module):
                code_emb = F.interpolate(code_emb, size=(x.shape[-1],), mode='linear')
                surrogate = self.surrogate_head(code_emb)

-           # Everything after this comment is timestep dependent.
            x = self.input_block(x)
+           x = torch.cat([x, code_emb], dim=1)

+           # Everything after this comment is timestep dependent.
+           hs = []
+           time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
            time_emb = time_emb.float()
            h = x
            for k, module in enumerate(self.input_blocks):
@@ -493,13 +443,11 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt):

if __name__ == '__main__':
    clip = torch.randn(2, 1, 32868)
-   aligned_sequence = torch.randn(2,120,128)
+   aligned_sequence = torch.randn(2,1,32868)
    ts = torch.LongTensor([600, 600])
    model = DiffusionTts(128,
                         channel_mult=[1,1.5,2, 3, 4, 6, 8],
                         num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
                         attention_resolutions=[],
                         num_heads=8,
                         kernel_size=3,
                         scale_factor=2,
                         time_embed_dim_multiplier=4,
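
The rest of the smoke test falls outside this hunk. Under the new signature forward(x, timesteps, conditioning, conditioning_free=False), it would presumably finish by closing the constructor call and invoking the model roughly as sketched here (an assumption, not shown in the commit):

# Assumed usage: the truth-shaped tensor is passed as `conditioning`; the model autoencodes it
# into a latent and conditions the diffusion UNet on that latent.
o = model(clip, ts, aligned_sequence)
print(o.shape)   # out_channels=2 (mean and variance); length presumably cropped back to the 32868-sample clip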