Get diffusion_dvae ready for prime time!

This commit is contained in:
James Betker 2021-09-16 22:43:10 -06:00
parent 1197ae1928
commit f78ce9d924
9 changed files with 61 additions and 39 deletions

View File

@ -78,7 +78,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
def get_related_audio_for_index(self, index):
if self.extra_samples <= 0:
return None
return None, 0
audiopath = self.audiopaths[index]
related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0]
assert audiopath in related_files
@ -123,10 +123,11 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
output = {
'clip': audio_norm,
'alt_clips': alt_files,
'num_alt_clips': actual_samples, # We need to pad so that the dataloader can combine these.
'path': filename,
}
if self.extra_samples > 0:
output['alt_clips'] = alt_files
output['num_alt_clips'] = actual_samples
return output
def __len__(self):

View File

@ -9,7 +9,7 @@ from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
import models.gpt_voice.my_dvae as mdvae
from utils.util import checkpoint
from utils.util import checkpoint, get_mask_from_lengths
class DiscreteEncoder(nn.Module):
@ -66,7 +66,6 @@ class DiffusionDVAE(nn.Module):
spectrogram_channels=80,
spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform.
dropout=0,
# 106496 -> 26624 -> 6656 -> 16664 -> 416 -> 104 -> 26 for ~5secs@22050Hz
channel_mult=(1, 2, 4, 8, 16, 32, 64),
attention_resolutions=(16,32,64),
conv_resample=True,
@ -81,6 +80,7 @@ class DiffusionDVAE(nn.Module):
quantize_dim=1024,
num_discrete_codes=8192,
scale_steps=4,
conditioning_inputs_provided=True,
):
super().__init__()
@ -121,9 +121,11 @@ class DiffusionDVAE(nn.Module):
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
self.input_blocks = nn.ModuleList(
[
@ -262,7 +264,7 @@ class DiffusionDVAE(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs):
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
spec_hs = self.decoder(embeddings)[::-1]
# Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
@ -272,11 +274,12 @@ class DiffusionDVAE(nn.Module):
hs = []
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if conditioning_inputs is not None:
mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add.
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
emb = self.embedding_combiner(emb, mask, self.query_gen(spec_hs[0]))
else:
emb = emb1.unsqueeze(1)
emb = self.embedding_combiner(emb, self.query_gen(spec_hs[0]))
emb = emb1
# The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
next_spec = spec_hs.pop(0)
@ -302,12 +305,12 @@ class DiffusionDVAE(nn.Module):
h = h.type(x.dtype)
return self.out(h)
def decode(self, x, timesteps, codes, conditioning_inputs=None):
def decode(self, x, timesteps, codes, conditioning_inputs=None, num_conditioning_signals=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals)
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
# Compute DVAE portion first.
@ -319,7 +322,7 @@ class DiffusionDVAE(nn.Module):
else:
# Compute from codes only.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss
@register_model
@ -329,9 +332,10 @@ def register_unet_diffusion_dvae(opt_net, opt):
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
clip = torch.randn(1, 1, 81920)
spec = torch.randn(1, 80, 416)
cond = torch.randn(1, 5, 80, 200)
ts = torch.LongTensor([555])
clip = torch.randn(4, 1, 81920)
spec = torch.randn(4, 80, 416)
cond = torch.randn(4, 5, 80, 200)
num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
ts = torch.LongTensor([432, 234, 100, 555])
model = DiffusionDVAE(32, 2)
print(model(clip, ts, spec, cond)[0].shape)
print(model(clip, ts, spec, cond, num_cond)[0].shape)

View File

@ -772,7 +772,7 @@ class GaussianDiffusion:
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
model_output = model_outputs[0]
if len(model_outputs) > 1:
terms['extra_outputs']: model_outputs[1:]
terms['extra_outputs'] = model_outputs[1:]
if self.model_var_type in [
ModelVarType.LEARNED,

View File

@ -315,14 +315,14 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, x)
def forward(self, x, mask=None):
return checkpoint(self._forward, x, mask)
def _forward(self, x):
def _forward(self, x, mask):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.attention(qkv, mask)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
@ -356,7 +356,7 @@ class QKVAttentionLegacy(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
def forward(self, qkv, mask=None):
"""
Apply QKV attention.
@ -372,7 +372,12 @@ class QKVAttentionLegacy(nn.Module):
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
@ -389,7 +394,7 @@ class QKVAttention(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
def forward(self, qkv, mask=None):
"""
Apply QKV attention.
@ -406,6 +411,10 @@ class QKVAttention(nn.Module):
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)

View File

@ -78,14 +78,14 @@ class QueryProvidedAttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, qx, kvx):
return checkpoint(self._forward, qx, kvx)
def forward(self, qx, kvx, mask=None):
return checkpoint(self._forward, qx, kvx, mask)
def _forward(self, qx, kvx):
def _forward(self, qx, kvx, mask=None):
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
kv = self.kv(self.norm(kvx.permute(0,2,1)))
qkv = torch.cat([q, kv], dim=1)
h = self.attention(qkv)
h = self.attention(qkv, mask)
h = self.proj_out(h)
return kvx + h.permute(0,2,1)
@ -100,14 +100,14 @@ class EmbeddingCombiner(nn.Module):
# x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
# cond: (b,d) or None
def forward(self, x_s, cond=None):
def forward(self, x_s, attn_mask=None, cond=None):
assert cond is not None and self.cond_provided or cond is None and not self.cond_provided
y = x_s
for blk in self.attn:
if self.cond_provided:
y = blk(cond, y)
y = blk(cond, y, mask=attn_mask)
else:
y = blk(y)
y = blk(y, mask=attn_mask)
return y[:, 0]

View File

@ -19,8 +19,8 @@ def main():
maximum_duration = 20
files = find_audio_files(args.path, include_nonwav=True)
for e, wav_file in enumerate(tqdm(files)):
#if e < 1459:
# continue
if e < 2759:
continue
print(f"Processing {wav_file}..")
outdir = os.path.join(args.out, f'{e}_{os.path.basename(wav_file[:-4])}').replace('.', '').strip()
os.makedirs(outdir, exist_ok=True)

View File

@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mydvae_audio_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -4,7 +4,7 @@ import torch.nn
from kornia.augmentation import RandomResizedCrop
from torch.cuda.amp import autocast
from trainer.inject import Injector
from trainer.inject import Injector, create_injector
from trainer.losses import extract_params_from_state
from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt

View File

@ -408,4 +408,12 @@ def denormalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
ten = x.clone().permute(1, 2, 3, 0)
for t, m, s in zip(ten, mean, std):
t.mul_(s).add_(m)
return torch.clamp(ten, 0, 1).permute(3, 0, 1, 2)
return torch.clamp(ten, 0, 1).permute(3, 0, 1, 2)
def get_mask_from_lengths(lengths, max_len=None):
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask