For loading conditional clips, default to falling back to loading the clip itself

James Betker 2021-12-30 09:10:14 -07:00
parent 5ae7e0d9b0
commit f2cd6a7f08
3 changed files with 23 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
 import os
 import os
 import random
+import shutil
 import torch
 import torch.nn.functional as F
@@ -189,8 +190,8 @@ if __name__ == '__main__':
            'use_bpe_tokenizer': False,
        },
        'unsupervised_audio_args': {
-           'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
-           'cache_path': 'test_cache_delete_me.pth',
+           'path': ['Y:\\clips\\podcasts-0\\6175_20170425-How the National Security Council Works'],
+           'cache_path': 'test_cache_delete_me2.pth',
        },
        'text_corpus_args': {
            'corpi': [['bookcorpus', '']],
@@ -216,6 +217,7 @@ if __name__ == '__main__':
        },
    }
    from data import create_dataset, create_dataloader
+   os.remove('test_cache_delete_me2.pth')
    ds, c = create_dataset(train_params, return_collate=True)
    dl = create_dataloader(ds, train_params, collate_fn=c)
@@ -233,14 +235,14 @@ if __name__ == '__main__':
    m = None
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
-           save(b, i, ib, 'paired_audio')
-           save(b, i, ib, 'paired_audio_conditioning', 0)
-           save(b, i, ib, 'paired_audio_conditioning', 1)
-           print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
-           print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
-           #save(b, i, ib, 'speech_audio')
-           #save(b, i, ib, 'speech_audio_conditioning', 0)
-           #save(b, i, ib, 'speech_audio_conditioning', 1)
+           #save(b, i, ib, 'paired_audio')
+           #save(b, i, ib, 'paired_audio_conditioning', 0)
+           #save(b, i, ib, 'paired_audio_conditioning', 1)
+           #print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
+           #print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
+           save(b, i, ib, 'speech_audio')
+           save(b, i, ib, 'speech_audio_conditioning', 0)
+           save(b, i, ib, 'speech_audio_conditioning', 1)
            #print(f'Text: {b["text_text"][ib]}')
            #print(f'Text decoded: {decode(b, ib, "text_tokens")}')
            if i > 5:

View File

@@ -51,19 +51,21 @@ def load_audio(audiopath, sampling_rate):
     return audio.unsqueeze(0)

-def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
+def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True, fallback_to_self=True):
     sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
     candidates = []
     if os.path.exists(sim_path):
         similarities = torch.load(sim_path)
         fname = os.path.basename(path)
         if fname in similarities.keys():
-            candidates = similarities[fname]
+            candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
         else:
             print(f'Similarities list found for {path} but {fname} was not in that list.')
     if len(candidates) == 0:
-        print(f"Falling back to non-similarity list for {path}")
-        candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
+        if fallback_to_self:
+            candidates = [path]
+        else:
+            candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
     assert len(candidates) < 50000  # Sanity check to ensure we aren't loading "related files" that aren't actually related.
     if not include_self:
@@ -75,7 +77,7 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True)
     # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
     related_clips = []
     for k in range(n):
-        rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
+        rel_clip = load_audio(random.choice(candidates), sample_rate)
         gap = rel_clip.shape[-1] - sample_length
         if gap < 0:
             rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
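
The net effect of these two hunks: when a clip has no entry in its directory's similarities.pth, the conditioning candidates now default to the clip itself rather than every audio file in the directory (the old directory-wide scan remains reachable with fallback_to_self=False), and similarity entries are resolved to full paths so the sampling loop no longer re-joins the directory. A minimal standalone sketch of that selection logic, assuming a hypothetical helper name and a plain .wav filter in place of the repository's find_files_of_type/is_audio_file utilities:

import os
import random
import torch

def pick_conditioning_candidate(path, fallback_to_self=True):
    # Illustrative sketch only; mirrors the candidate selection in load_similar_clips.
    clip_dir = os.path.dirname(path)
    sim_path = os.path.join(clip_dir, 'similarities.pth')
    candidates = []
    if os.path.exists(sim_path):
        # similarities.pth maps a clip's basename to the basenames of its most similar neighbors.
        similarities = torch.load(sim_path)
        fname = os.path.basename(path)
        if fname in similarities:
            candidates = [os.path.join(clip_dir, s) for s in similarities[fname]]
    if len(candidates) == 0:
        if fallback_to_self:
            candidates = [path]  # new default: condition on the clip itself
        else:
            # old behavior: fall back to any audio file in the same directory
            candidates = [os.path.join(clip_dir, f) for f in os.listdir(clip_dir) if f.endswith('.wav')]
    return random.choice(candidates)

load_similar_clips itself draws n such candidates with replacement and pads or crops each to sample_length, as shown in the second hunk above.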

View File

@@ -45,6 +45,8 @@ def process_subdir(subdir, options, clip_sz):
     clip_model = load_model_from_config(preloaded_options=options, model_name='clip', also_load_savepoint=True)
     root, paths = subdir
+    if len(paths) == 0:
+        return
     root = str(root)
     output_file = os.path.join(root, 'similarities.pth')
     if os.path.exists(output_file):
@ -100,8 +102,8 @@ if __name__ == '__main__':
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml') parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml')
parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=1) parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=6)
parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\clips\\podcasts-0\\5177_20190625-Food Waste is Solvable') parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\bigasr_dataset\\tedlium')
parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050) parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050)
args = parser.parse_args() args = parser.parse_args()