forked from mrq/DL-Art-School
codes generation script
This commit is contained in:
parent
2f4d990ad1
commit
d3a60633a3
|
@ -111,6 +111,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
self.pad_to *= self.sampling_rate
|
self.pad_to *= self.sampling_rate
|
||||||
self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
|
self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
|
||||||
self.min_length = opt_get(opt, ['min_length'], 0)
|
self.min_length = opt_get(opt, ['min_length'], 0)
|
||||||
|
self.dont_clip = opt_get(opt, ['dont_clip'], False)
|
||||||
|
|
||||||
# "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
|
# "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
|
||||||
# guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
|
# guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
|
||||||
|
@ -126,6 +127,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
audiopath = self.audiopaths[index]
|
audiopath = self.audiopaths[index]
|
||||||
audio = load_audio(audiopath, self.sampling_rate)
|
audio = load_audio(audiopath, self.sampling_rate)
|
||||||
assert audio.shape[1] > self.min_length
|
assert audio.shape[1] > self.min_length
|
||||||
|
if self.dont_clip:
|
||||||
|
assert audio.shape[1] <= self.pad_to
|
||||||
return audio, audiopath
|
return audio, audiopath
|
||||||
|
|
||||||
def get_related_audio_for_index(self, index):
|
def get_related_audio_for_index(self, index):
|
||||||
|
|
41
codes/scripts/audio/preparation/gen_dvae_codes.py
Normal file
41
codes/scripts/audio/preparation/gen_dvae_codes.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from scripts.audio.gen.speech_synthesis_utils import load_speech_dvae, wav_to_mel
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Walks every audio clip found under the input folder, runs it through the speech DVAE and
    # saves the resulting codebook indices as a list, one "<relative audio path>.pth" file per
    # clip, mirroring the input directory layout under the output folder.
    #
    # Both folders can be overridden on the command line; the defaults are the locations this
    # script was originally written against, so running it without arguments behaves as before.
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate DVAE codebook indices for every audio clip in a folder.')
    parser.add_argument('--input_folder', type=str,
                        default='C:\\Users\\James\\Downloads\\lex2\\lexfridman_training_mp3',
                        help='Root folder containing the audio clips to encode.')
    parser.add_argument('--output_folder', type=str,
                        default='C:\\Users\\James\\Downloads\\lex2\\quantized',
                        help='Root folder that receives one .pth file of codes per clip.')
    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder

    # Dataset/dataloader options handed to the project's create_dataset/create_dataloader.
    params = {
        'mode': 'unsupervised_audio',
        'path': [input_folder],
        'cache_path': f'{input_folder}/cache.pth',
        'sampling_rate': 22050,
        'pad_to_samples': 441000,
        'resample_clip': False,
        'extra_samples': 0,
        'phase': 'train',
        'n_workers': 2,
        'batch_size': 64,
    }
    from data import create_dataset, create_dataloader
    os.makedirs(output_folder, exist_ok=True)

    ds = create_dataset(params)
    dl = create_dataloader(ds, params)

    dvae = load_speech_dvae().cuda()
    # Inference only: no gradients are needed to read out codebook indices.
    with torch.no_grad():
        for batch in tqdm(dl):
            audio = batch['clip'].cuda()
            mel = wav_to_mel(audio)
            codes = dvae.get_codebook_indices(mel)
            for i in range(audio.shape[0]):
                # Convert the clip length to a plain Python int before dividing so the slice
                # bound comes from integer arithmetic rather than tensor floor-division.
                # NOTE(review): 1024 is presumably the number of audio samples represented by
                # one code - confirm against the mel/DVAE configuration.
                # +4 seems empirically to be a good clipping point - it seems to preserve the termination codes.
                code_count = int(batch['clip_lengths'][i]) // 1024 + 4
                c = codes[i, :code_count]
                fn = batch['path'][i]
                # Mirror the clip's location relative to the input folder under the output folder.
                outp = os.path.join(output_folder, os.path.relpath(fn, input_folder) + ".pth")
                os.makedirs(os.path.dirname(outp), exist_ok=True)
                torch.save(c.tolist(), outp)
|
Loading…
Reference in New Issue
Block a user