pull/9/head
James Betker 2022-06-19 18:56:27 +07:00
commit b19b0a74da
3 changed files with 133 additions and 38 deletions

@@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
use_fp16=False,
ar_prior=False,
new_code_expansion=False,
permute_codes=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# Parameters for re-training head
@@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.new_code_expansion = new_code_expansion
self.permute_codes = permute_codes
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
@@ -210,8 +212,8 @@ class TransformerDiffusion(nn.Module):
def timestep_independent(self, prior, expected_seq_len):
if self.new_code_expansion:
code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb)
prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
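The hunk above reworks timestep_independent so that, when new_code_expansion is set, the interpolation writes back into prior and the AR input or converter layer is then fed prior directly. A standalone sketch of that interpolation step; the tensor shapes are illustrative assumptions, not values from the commit:

import torch
import torch.nn.functional as F

prior = torch.randn(2, 100, 256)   # (batch, code_sequence, vec_dim) -- assumed layout
expected_seq_len = 400
# stretch the code sequence to the diffusion sequence length in a channels-first view,
# then restore (batch, sequence, vec_dim)
prior = F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len, mode='linear').permute(0, 2, 1)
assert prior.shape == (2, 400, 256)
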
@@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
if precomputed_code_embeddings is not None:
assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
if self.permute_codes:
codes = codes.permute(0,2,1)
unused_params = []
if conditioning_free:
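The new permute_codes flag only transposes the incoming codes before any other processing. A shape-level sketch using the (2, 256, 400) clip that test_tfd below passes in; the channels-first input convention is an assumption:

import torch

codes = torch.randn(2, 256, 400)    # (batch, channels, sequence), as in test_tfd below
codes = codes.permute(0, 2, 1)      # -> (batch, sequence, channels) == (2, 400, 256)
# the permuted layout lines up with input_vec_dim=256 in the test configuration
assert codes.shape == (2, 400, 256)
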
@@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
def test_tfd():
clip = torch.randn(2,256,400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
prenet_channels=1024, num_heads=3, permute_codes=True,
input_vec_dim=256, num_layers=12, prenet_layers=4,
dropout=.1)
model(clip, ts, clip)
def test_quant_model():
clip = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
@@ -732,14 +745,14 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
model_channels=1024, contraction_dim=512,
prenet_channels=1024, num_heads=8,
input_vec_dim=256, num_layers=12, prenet_layers=6,
model_channels=1536, contraction_dim=768,
prenet_channels=1024, num_heads=12,
input_vec_dim=256, num_layers=20, prenet_layers=6,
dropout=.1, new_code_expansion=True,
)
diff_weights = torch.load('extracted_diff.pth')
model.diff.load_state_dict(diff_weights, strict=False)
cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
#diff_weights = torch.load('extracted_diff.pth')
#model.diff.load_state_dict(diff_weights, strict=False)
cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\60000_generator_ema.pth')
cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
vqargs= {'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
@@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):
if __name__ == '__main__':
#extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
test_cheater_model()
test_tfd()

@@ -0,0 +1,97 @@
"""
Master script that processes all separated no_vocals.wav chunks found in an input directory. Concatenates
consecutive chunks up to a predetermined duration and saves each combined span as a mel spectrogram .npz.
"""
import argparse
import functools
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from data.util import find_audio_files, find_files_of_type
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
from utils.util import load_audio
def report_progress(progress_file, file):
with open(progress_file, 'a', encoding='utf-8') as f:
f.write(f'{file}\n')
spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
'true_normalization': True, 'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
def produce_mel(audio):
return spec_fn({'in': audio.unsqueeze(0)})['out'].squeeze(0)
def process_folder(folder, base_path, output_path, progress_file, max_duration, sampling_rate=22050):
outdir = os.path.join(output_path, f'{os.path.relpath(folder, base_path)}')
os.makedirs(outdir, exist_ok=True)
files = list(os.listdir(folder))
i = 0
output_i = 0
while i < len(files):
last_ordinal = -1
total_progress = 0
to_combine = []
while i < len(files) and total_progress < max_duration:
audio_file = os.path.join(folder, files[i], "no_vocals.wav")
if not os.path.exists(audio_file):
break
to_combine.append(load_audio(audio_file, 22050))
file_ordinal = int(files[i])
if last_ordinal != -1 and file_ordinal != last_ordinal+1:
break
else:
i += 1
total_progress += 30
if total_progress > 30:
combined = torch.cat(to_combine, dim=-1).cuda()
mel = produce_mel(combined).cpu().numpy()
np.savez(os.path.join(outdir, f'{output_i}'), mel)
output_i += 1
report_progress(progress_file, folder)
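
# Behaviour notes, read from the loop above (not documented in the commit): each numbered
# sub-folder is assumed to hold a ~30 second no_vocals.wav chunk (total_progress advances
# by 30 per file); spans that collect only a single chunk (total_progress == 30) are not
# written out; and last_ordinal is never reassigned inside the loop, so the
# consecutive-numbering check cannot currently trigger.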
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, help='Path to search for files', default='Y:\\separated')
parser.add_argument('--progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\separated\\large_mels\\already_processed.txt')
parser.add_argument('--output_path', type=str, help='Path for output files', default='Y:\\separated\\large_mels')
parser.add_argument('--num_threads', type=int, help='Number of concurrent workers processing files.', default=3)
parser.add_argument('--max_duration', type=int, help='Duration per clip in seconds', default=120)
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
processed_files = set()
if os.path.exists(args.progress_file):
with open(args.progress_file, 'r', encoding='utf-8') as f:
for line in f.readlines():
processed_files.add(line.strip())
cache_path = os.path.join(args.output_path, 'cache.pth')
if os.path.exists(cache_path):
root_music_files = torch.load(cache_path)
else:
path = Path(args.path)
def collect(p):
return str(os.path.dirname(os.path.dirname(p)))
root_music_files = {collect(p) for p in path.rglob("*no_vocals.wav")}
torch.save(root_music_files, cache_path)
orig_len = len(root_music_files)
folders = root_music_files - processed_files
print(f"Found {len(folders)} files to process. Total processing is {100 * (orig_len - len(folders)) / orig_len}% complete.")
with ThreadPool(args.num_threads) as pool:
list(tqdm(pool.imap(functools.partial(process_folder, output_path=args.output_path, base_path=args.path,
progress_file=args.progress_file, max_duration=args.max_duration), folders), total=len(folders)))
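
For reference, a hypothetical direct call to process_folder with placeholder paths (not part of the commit); with max_duration=120, up to four consecutive 30-second chunks are concatenated into each saved mel:

# Placeholder paths for illustration only.
process_folder(folder='Y:\\separated\\some_album',
               base_path='Y:\\separated',
               output_path='Y:\\separated\\large_mels',
               progress_file='Y:\\separated\\large_mels\\already_processed.txt',
               max_duration=120)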

@@ -79,24 +79,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
return list(glob(f'{path}/*.wav'))
def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
real_resampled = audio
audio = audio.unsqueeze(0)
output_shape = (1, 16, audio.shape[-1] // 16)
output_shape = (1, 256, audio.shape[-1] // 256)
mel = self.spec_fn({'in': audio})['out']
gen = self.diffuser.p_sample_loop(self.model, output_shape,
model_kwargs={'aligned_conditioning': mel})
gen = pixel_shuffle_1d(gen, 16)
model_kwargs={'codes': mel})
gen = pixel_shuffle_1d(gen, 256)
return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
def perform_diffusion_from_codes(self, audio, sample_rate=22050):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
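
In the reworked perform_diffusion_spec_decode above, the diffusion output now carries 256 channels per frame and pixel_shuffle_1d folds those channels back into the time axis to recover a waveform. A generic sketch of that fold; the repo's pixel_shuffle_1d may differ in detail, but the shape contract assumed here is (batch, channels*r, length) -> (batch, channels, length*r):

import torch

def pixel_shuffle_1d_sketch(x, r):
    # (b, c*r, l) -> (b, c, l*r): interleave groups of r channels along the time axis
    b, cr, l = x.shape
    c = cr // r
    return x.view(b, c, r, l).permute(0, 1, 3, 2).reshape(b, c, l * r)

gen = torch.randn(1, 256, 100)              # illustrative diffusion output
wav = pixel_shuffle_1d_sketch(gen, 256)     # -> (1, 1, 25600)
assert wav.shape == (1, 1, 25600)
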
@@ -116,10 +110,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
@@ -148,10 +139,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
@@ -174,10 +162,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
real_resampled = audio
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
@@ -273,17 +258,17 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41500_generator_ema.pth'
load_path='X:\\dlas\\experiments\\train_music_waveform_gen_retry\\models\\22000_generator_ema.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 200,
'conditioning_free': True, 'conditioning_free_k': 2,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode',
#'partial_low': 128, 'partial_high': 192
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 605, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 100, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())