forked from mrq/DL-Art-School
Merge branch 'master' of https://github.com/neonbjb/DL-Art-School
commit b19b0a74da
@@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
             use_fp16=False,
             ar_prior=False,
             new_code_expansion=False,
+            permute_codes=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for re-training head
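The new permute_codes flag sits next to unconditioned_percentage, which drops the conditioning for a fraction of batch elements during training, in the spirit of classifier-free guidance. That masking is not part of this hunk; below is a minimal sketch of the usual mechanism (maybe_drop_conditioning and the (1, 1, C) unconditioned embedding are illustrative names, not this file's):

import torch

def maybe_drop_conditioning(code_emb, unconditioned_embedding, unconditioned_percentage):
    # code_emb: (B, T, C); unconditioned_embedding: (1, 1, C), learned.
    # With probability `unconditioned_percentage`, a whole batch element's conditioning is
    # replaced by the learned "unconditioned" embedding, so the model also learns p(x) alone.
    if unconditioned_percentage > 0:
        drop = torch.rand(code_emb.shape[0], 1, 1, device=code_emb.device) < unconditioned_percentage
        code_emb = torch.where(drop, unconditioned_embedding.to(code_emb.dtype), code_emb)
    return code_emb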
@@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
         self.new_code_expansion = new_code_expansion
+        self.permute_codes = permute_codes

         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)

@@ -210,8 +212,8 @@ class TransformerDiffusion(nn.Module):

     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
-            code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
-        code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb)
+            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
         code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)

         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
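With this change, new_code_expansion stretches the prior itself to expected_seq_len, so both the AR-prior path and the standard input converter consume the expanded tensor rather than only an interpolated code_emb. A standalone shape check of the interpolation, assuming the (batch, sequence, channels) layout used here:

import torch
import torch.nn.functional as F

prior = torch.randn(2, 100, 256)                  # (batch, seq, channels)
expected_seq_len = 400
expanded = F.interpolate(prior.permute(0, 2, 1),  # interpolate acts on the last dim, so go channels-first
                         size=expected_seq_len, mode='linear').permute(0, 2, 1)
assert expanded.shape == (2, 400, 256)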
@@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
         if precomputed_code_embeddings is not None:
             assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
+        if self.permute_codes:
+            codes = codes.permute(0,2,1)

         unused_params = []
         if conditioning_free:
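permute_codes simply swaps the last two axes of the incoming codes, which appears intended to let a (batch, channels, time) tensor such as a mel spectrogram be passed where the model otherwise expects (batch, time, vec_dim); the new test_tfd below exercises exactly that with a (2, 256, 400) clip. Shape-wise:

import torch

codes = torch.randn(2, 256, 400)  # (batch, channels, time), as in test_tfd
codes = codes.permute(0, 2, 1)    # -> (batch, time, channels) == (2, 400, 256)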
@@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
     return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])


+def test_tfd():
+    clip = torch.randn(2,256,400)
+    ts = torch.LongTensor([600, 600])
+    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
+                                 prenet_channels=1024, num_heads=3, permute_codes=True,
+                                 input_vec_dim=256, num_layers=12, prenet_layers=4,
+                                 dropout=.1)
+    model(clip, ts, clip)
+
 def test_quant_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
@@ -732,14 +745,14 @@ def test_cheater_model():

     # For music:
     model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
-                                                  model_channels=1024, contraction_dim=512,
-                                                  prenet_channels=1024, num_heads=8,
-                                                  input_vec_dim=256, num_layers=12, prenet_layers=6,
+                                                  model_channels=1536, contraction_dim=768,
+                                                  prenet_channels=1024, num_heads=12,
+                                                  input_vec_dim=256, num_layers=20, prenet_layers=6,
                                                   dropout=.1, new_code_expansion=True,
                                                   )
-    diff_weights = torch.load('extracted_diff.pth')
-    model.diff.load_state_dict(diff_weights, strict=False)
-    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
+    #diff_weights = torch.load('extracted_diff.pth')
+    #model.diff.load_state_dict(diff_weights, strict=False)
+    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\60000_generator_ema.pth')
     cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
                                vqargs= {'positional_dims': 1, 'channels': 64,
                                        'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
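test_cheater_model now builds the larger configuration (model_channels=1536, 20 layers) and loads the newer 60000-step cheater AR EMA checkpoint, while the extracted diffusion weights are commented out. Checkpoints loaded with strict=False can silently skip keys; a small illustrative helper for inspecting what matched (load_partial is not part of the repo):

import torch

def load_partial(module, ckpt_path):
    # Load a possibly-partial state dict and report what did not line up.
    sd = torch.load(ckpt_path, map_location='cpu')
    missing, unexpected = module.load_state_dict(sd, strict=False)
    print(f'{ckpt_path}: {len(missing)} missing keys, {len(unexpected)} unexpected keys')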
@@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):

 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    test_tfd()
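extract_diff's body is not part of this diff; only its signature and the fact that its output feeds model.diff.load_state_dict(..., strict=False) are visible. Purely as an illustration of that contract (the 'diff.' prefix and the head-key filter are guesses, not the repo's actual logic):

import torch

def extract_diff_sketch(in_f, out_f, remove_head=False):
    # Hypothetical: keep only the wrapped diffusion model's weights from a full checkpoint.
    sd = torch.load(in_f, map_location='cpu')
    diff_sd = {k[len('diff.'):]: v for k, v in sd.items() if k.startswith('diff.')}
    if remove_head:
        diff_sd = {k: v for k, v in diff_sd.items() if not k.startswith('out')}
    torch.save(diff_sd, out_f)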

codes/scripts/audio/prep_music/generate_long_mels.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+"""
+Master script that processes all MP3 files found in an input directory. Splits those files up into sub-files of a
+predetermined duration.
+"""
+import argparse
+import functools
+import os
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+
+import torch
+import torchaudio
+import numpy as np
+from tqdm import tqdm
+
+from data.util import find_audio_files, find_files_of_type
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from utils.util import load_audio
+
+
+def report_progress(progress_file, file):
+    with open(progress_file, 'a', encoding='utf-8') as f:
+        f.write(f'{file}\n')
+
+
+spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
+                                       'true_normalization': True, 'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
+
+def produce_mel(audio):
+    return spec_fn({'in': audio.unsqueeze(0)})['out'].squeeze(0)
+
+
+def process_folder(folder, base_path, output_path, progress_file, max_duration, sampling_rate=22050):
+    outdir = os.path.join(output_path, f'{os.path.relpath(folder, base_path)}')
+    os.makedirs(outdir, exist_ok=True)
+
+    files = list(os.listdir(folder))
+    i = 0
+    output_i = 0
+    while i < len(files):
+        last_ordinal = -1
+        total_progress = 0
+        to_combine = []
+        while i < len(files) and total_progress < max_duration:
+            audio_file = os.path.join(folder, files[i], "no_vocals.wav")
+            if not os.path.exists(audio_file):
+                break
+            to_combine.append(load_audio(audio_file, 22050))
+            file_ordinal = int(files[i])
+            if last_ordinal != -1 and file_ordinal != last_ordinal+1:
+                break
+            else:
+                i += 1
+                total_progress += 30
+        if total_progress > 30:
+            combined = torch.cat(to_combine, dim=-1).cuda()
+            mel = produce_mel(combined).cpu().numpy()
+            np.savez(os.path.join(outdir, f'{output_i}'), mel)
+            output_i += 1
+    report_progress(progress_file, folder)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, help='Path to search for files', default='Y:\\separated')
+    parser.add_argument('--progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\separated\\large_mels\\already_processed.txt')
+    parser.add_argument('--output_path', type=str, help='Path for output files', default='Y:\\separated\\large_mels')
+    parser.add_argument('--num_threads', type=int, help='Number of concurrent workers processing files.', default=3)
+    parser.add_argument('--max_duration', type=int, help='Duration per clip in seconds', default=120)
+    args = parser.parse_args()
+
+    os.makedirs(args.output_path, exist_ok=True)
+    processed_files = set()
+    if os.path.exists(args.progress_file):
+        with open(args.progress_file, 'r', encoding='utf-8') as f:
+            for line in f.readlines():
+                processed_files.add(line.strip())
+
+    cache_path = os.path.join(args.output_path, 'cache.pth')
+    if os.path.exists(cache_path):
+        root_music_files = torch.load(cache_path)
+    else:
+        path = Path(args.path)
+        def collect(p):
+            return str(os.path.dirname(os.path.dirname(p)))
+        root_music_files = {collect(p) for p in path.rglob("*no_vocals.wav")}
+        torch.save(root_music_files, cache_path)
+
+    orig_len = len(root_music_files)
+    folders = root_music_files - processed_files
+    print(f"Found {len(folders)} files to process. Total processing is {100 * (orig_len - len(folders)) / orig_len}% complete.")
+
+    with ThreadPool(args.num_threads) as pool:
+        list(tqdm(pool.imap(functools.partial(process_folder, output_path=args.output_path, base_path=args.path,
+                                              progress_file=args.progress_file, max_duration=args.max_duration), folders), total=len(folders)))
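process_folder walks numerically named subfolders, concatenating consecutive no_vocals.wav stems (each counted as roughly 30 seconds) until max_duration is reached, then writes the combined segment's mel to an .npz. np.savez with a positional argument stores the array under the default 'arr_0' key, so reading a saved mel back looks like this (the path is illustrative):

import numpy as np

mel = np.load('Y:/separated/large_mels/some_album/0.npz')['arr_0']
print(mel.shape)  # (n_mel_channels, frames) == (256, frames)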
@@ -79,24 +79,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return list(glob(f'{path}/*.wav'))

     def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
+        real_resampled = audio
         audio = audio.unsqueeze(0)
-        output_shape = (1, 16, audio.shape[-1] // 16)
+        output_shape = (1, 256, audio.shape[-1] // 256)
         mel = self.spec_fn({'in': audio})['out']
         gen = self.diffuser.p_sample_loop(self.model, output_shape,
-                                          model_kwargs={'aligned_conditioning': mel})
-        gen = pixel_shuffle_1d(gen, 16)
+                                          model_kwargs={'codes': mel})
+        gen = pixel_shuffle_1d(gen, 256)

         return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

     def perform_diffusion_from_codes(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
+        real_resampled = audio
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
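The removed branch could never execute (sample_rate != sample_rate is always false), so the dead resample path is dropped here and in the three methods below. The spec_decode path now treats the generator output as 256 interleaved channels instead of 16 and passes the mel under the codes kwarg. For reference, a 1-D pixel shuffle of factor r conventionally folds (B, C*r, T) into (B, C, T*r); a sketch of that convention (the repo's pixel_shuffle_1d may differ in detail):

import torch

def pixel_shuffle_1d_sketch(x, r):
    # (B, C*r, T) -> (B, C, T*r): interleave groups of r channels along the time axis.
    b, c, t = x.shape
    x = x.view(b, c // r, r, t)
    return x.permute(0, 1, 3, 2).reshape(b, c // r, t * r)

assert pixel_shuffle_1d_sketch(torch.randn(1, 256, 100), 256).shape == (1, 1, 25600)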
@@ -116,10 +110,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

     def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
+        real_resampled = audio
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
@@ -148,10 +139,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

     def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
+        real_resampled = audio
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
@@ -174,10 +162,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

     def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
-        if sample_rate != sample_rate:
-            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
-        else:
-            real_resampled = audio
+        real_resampled = audio
         audio = audio.unsqueeze(0)

         mel = self.spec_fn({'in': audio})['out']
@@ -273,17 +258,17 @@ class MusicDiffusionFid(evaluator.Evaluator):


 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41500_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen_retry\\models\\22000_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 200,
-                'conditioning_free': True, 'conditioning_free_k': 2,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
+                'diffusion_steps': 100,
+                'conditioning_free': False, 'conditioning_free_k': 1,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode',
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 605, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 100, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
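The eval harness is switched to the spec_decode path with conditioning-free guidance turned off (conditioning_free=False, k=1), 100 diffusion steps, and the new waveform-gen checkpoint. For context, a common way to combine the conditional and unconditional passes when guidance is enabled; this is the usual formulation, not necessarily the exact one this repo uses:

def guided_prediction(cond, uncond, k):
    # k == 1 reduces to the conditional prediction; larger k leans harder on the conditioning.
    return uncond + k * (cond - uncond)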