diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py index 6ab7cf9c..c1a38eaf 100644 --- a/codes/data/audio/unsupervised_audio_dataset.py +++ b/codes/data/audio/unsupervised_audio_dataset.py @@ -78,7 +78,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): def get_related_audio_for_index(self, index): if self.extra_samples <= 0: - return None + return None, 0 audiopath = self.audiopaths[index] related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0] assert audiopath in related_files @@ -123,10 +123,11 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): output = { 'clip': audio_norm, - 'alt_clips': alt_files, - 'num_alt_clips': actual_samples, # We need to pad so that the dataloader can combine these. 'path': filename, } + if self.extra_samples > 0: + output['alt_clips'] = alt_files + output['num_alt_clips'] = actual_samples return output def __len__(self): diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index a22c57fe..0fbf70ab 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -9,7 +9,7 @@ from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner from models.vqvae.vqvae import Quantize from trainer.networks import register_model import models.gpt_voice.my_dvae as mdvae -from utils.util import checkpoint +from utils.util import checkpoint, get_mask_from_lengths class DiscreteEncoder(nn.Module): @@ -66,7 +66,6 @@ class DiffusionDVAE(nn.Module): spectrogram_channels=80, spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform. dropout=0, - # 106496 -> 26624 -> 6656 -> 16664 -> 416 -> 104 -> 26 for ~5secs@22050Hz channel_mult=(1, 2, 4, 8, 16, 32, 64), attention_resolutions=(16,32,64), conv_resample=True, @@ -81,6 +80,7 @@ class DiffusionDVAE(nn.Module): quantize_dim=1024, num_discrete_codes=8192, scale_steps=4, + conditioning_inputs_provided=True, ): super().__init__() @@ -121,9 +121,11 @@ class DiffusionDVAE(nn.Module): nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) - self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim) - self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim) - self.embedding_combiner = EmbeddingCombiner(time_embed_dim) + + if conditioning_inputs_provided: + self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim) + self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim) + self.embedding_combiner = EmbeddingCombiner(time_embed_dim) self.input_blocks = nn.ModuleList( [ @@ -262,7 +264,7 @@ class DiffusionDVAE(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs): + def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals): spec_hs = self.decoder(embeddings)[::-1] # Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.) spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)] @@ -272,11 +274,12 @@ class DiffusionDVAE(nn.Module): hs = [] emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels)) if conditioning_inputs is not None: + mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add. emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1) emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1) + emb = self.embedding_combiner(emb, mask, self.query_gen(spec_hs[0])) else: - emb = emb1.unsqueeze(1) - emb = self.embedding_combiner(emb, self.query_gen(spec_hs[0])) + emb = emb1 # The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder. next_spec = spec_hs.pop(0) @@ -302,12 +305,12 @@ class DiffusionDVAE(nn.Module): h = h.type(x.dtype) return self.out(h) - def decode(self, x, timesteps, codes, conditioning_inputs=None): + def decode(self, x, timesteps, codes, conditioning_inputs=None, num_conditioning_signals=None): assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. embeddings = self.quantizer.embed_code(codes).permute((0,2,1)) - return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss + return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals) - def forward(self, x, timesteps, spectrogram, conditioning_inputs=None): + def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None): assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. # Compute DVAE portion first. @@ -319,7 +322,7 @@ class DiffusionDVAE(nn.Module): else: # Compute from codes only. embeddings = self.quantizer.embed_code(codes).permute((0,2,1)) - return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss + return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss @register_model @@ -329,9 +332,10 @@ def register_unet_diffusion_dvae(opt_net, opt): # Test for ~4 second audio clip at 22050Hz if __name__ == '__main__': - clip = torch.randn(1, 1, 81920) - spec = torch.randn(1, 80, 416) - cond = torch.randn(1, 5, 80, 200) - ts = torch.LongTensor([555]) + clip = torch.randn(4, 1, 81920) + spec = torch.randn(4, 80, 416) + cond = torch.randn(4, 5, 80, 200) + num_cond = torch.tensor([2,4,5,3], dtype=torch.long) + ts = torch.LongTensor([432, 234, 100, 555]) model = DiffusionDVAE(32, 2) - print(model(clip, ts, spec, cond)[0].shape) + print(model(clip, ts, spec, cond, num_cond)[0].shape) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 0980ae8a..5ee1fc13 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -772,7 +772,7 @@ class GaussianDiffusion: model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) model_output = model_outputs[0] if len(model_outputs) > 1: - terms['extra_outputs']: model_outputs[1:] + terms['extra_outputs'] = model_outputs[1:] if self.model_var_type in [ ModelVarType.LEARNED, diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index 42c39715..af5d17f4 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -315,14 +315,14 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - def forward(self, x): - return checkpoint(self._forward, x) + def forward(self, x, mask=None): + return checkpoint(self._forward, x, mask) - def _forward(self, x): + def _forward(self, x, mask): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) + h = self.attention(qkv, mask) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) @@ -356,7 +356,7 @@ class QKVAttentionLegacy(nn.Module): super().__init__() self.n_heads = n_heads - def forward(self, qkv): + def forward(self, qkv, mask=None): """ Apply QKV attention. @@ -372,7 +372,12 @@ class QKVAttentionLegacy(nn.Module): "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) @staticmethod @@ -389,7 +394,7 @@ class QKVAttention(nn.Module): super().__init__() self.n_heads = n_heads - def forward(self, qkv): + def forward(self, qkv, mask=None): """ Apply QKV attention. @@ -406,6 +411,10 @@ class QKVAttention(nn.Module): (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterwards + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py index 84fd58fa..517d0e5c 100644 --- a/codes/models/gpt_voice/mini_encoder.py +++ b/codes/models/gpt_voice/mini_encoder.py @@ -78,14 +78,14 @@ class QueryProvidedAttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - def forward(self, qx, kvx): - return checkpoint(self._forward, qx, kvx) + def forward(self, qx, kvx, mask=None): + return checkpoint(self._forward, qx, kvx, mask) - def _forward(self, qx, kvx): + def _forward(self, qx, kvx, mask=None): q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1) kv = self.kv(self.norm(kvx.permute(0,2,1))) qkv = torch.cat([q, kv], dim=1) - h = self.attention(qkv) + h = self.attention(qkv, mask) h = self.proj_out(h) return kvx + h.permute(0,2,1) @@ -100,14 +100,14 @@ class EmbeddingCombiner(nn.Module): # x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim # cond: (b,d) or None - def forward(self, x_s, cond=None): + def forward(self, x_s, attn_mask=None, cond=None): assert cond is not None and self.cond_provided or cond is None and not self.cond_provided y = x_s for blk in self.attn: if self.cond_provided: - y = blk(cond, y) + y = blk(cond, y, mask=attn_mask) else: - y = blk(y) + y = blk(y, mask=attn_mask) return y[:, 0] diff --git a/codes/scripts/audio/preparation/split_on_silence.py b/codes/scripts/audio/preparation/split_on_silence.py index 7f057a24..a53f60d3 100644 --- a/codes/scripts/audio/preparation/split_on_silence.py +++ b/codes/scripts/audio/preparation/split_on_silence.py @@ -19,8 +19,8 @@ def main(): maximum_duration = 20 files = find_audio_files(args.path, include_nonwav=True) for e, wav_file in enumerate(tqdm(files)): - #if e < 1459: - # continue + if e < 2759: + continue print(f"Processing {wav_file}..") outdir = os.path.join(args.out, f'{e}_{os.path.basename(wav_file[:-4])}').replace('.', '').strip() os.makedirs(outdir, exist_ok=True) diff --git a/codes/train.py b/codes/train.py index 6ce62a94..f4de0a21 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mydvae_audio_clips.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 5d1ee799..90a7ba69 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -4,7 +4,7 @@ import torch.nn from kornia.augmentation import RandomResizedCrop from torch.cuda.amp import autocast -from trainer.inject import Injector +from trainer.inject import Injector, create_injector from trainer.losses import extract_params_from_state from utils.util import opt_get from utils.weight_scheduler import get_scheduler_for_opt diff --git a/codes/utils/util.py b/codes/utils/util.py index ac0c9888..bff74c10 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -408,4 +408,12 @@ def denormalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): ten = x.clone().permute(1, 2, 3, 0) for t, m, s in zip(ten, mean, std): t.mul_(s).add_(m) - return torch.clamp(ten, 0, 1).permute(3, 0, 1, 2) \ No newline at end of file + return torch.clamp(ten, 0, 1).permute(3, 0, 1, 2) + + +def get_mask_from_lengths(lengths, max_len=None): + if max_len is None: + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask