When loading conditioning clips, fall back by default to loading the clip itself
This commit is contained in:
parent
5ae7e0d9b0
commit
f2cd6a7f08
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -189,8 +190,8 @@ if __name__ == '__main__':
|
||||||
'use_bpe_tokenizer': False,
|
'use_bpe_tokenizer': False,
|
||||||
},
|
},
|
||||||
'unsupervised_audio_args': {
|
'unsupervised_audio_args': {
|
||||||
'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
|
'path': ['Y:\\clips\\podcasts-0\\6175_20170425-How the National Security Council Works'],
|
||||||
'cache_path': 'test_cache_delete_me.pth',
|
'cache_path': 'test_cache_delete_me2.pth',
|
||||||
},
|
},
|
||||||
'text_corpus_args': {
|
'text_corpus_args': {
|
||||||
'corpi': [['bookcorpus', '']],
|
'corpi': [['bookcorpus', '']],
|
||||||
|
@ -216,6 +217,7 @@ if __name__ == '__main__':
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
os.remove('test_cache_delete_me2.pth')
|
||||||
|
|
||||||
ds, c = create_dataset(train_params, return_collate=True)
|
ds, c = create_dataset(train_params, return_collate=True)
|
||||||
dl = create_dataloader(ds, train_params, collate_fn=c)
|
dl = create_dataloader(ds, train_params, collate_fn=c)
|
||||||
|
@ -233,14 +235,14 @@ if __name__ == '__main__':
|
||||||
m = None
|
m = None
|
||||||
for i, b in tqdm(enumerate(dl)):
|
for i, b in tqdm(enumerate(dl)):
|
||||||
for ib in range(batch_sz):
|
for ib in range(batch_sz):
|
||||||
save(b, i, ib, 'paired_audio')
|
#save(b, i, ib, 'paired_audio')
|
||||||
save(b, i, ib, 'paired_audio_conditioning', 0)
|
#save(b, i, ib, 'paired_audio_conditioning', 0)
|
||||||
save(b, i, ib, 'paired_audio_conditioning', 1)
|
#save(b, i, ib, 'paired_audio_conditioning', 1)
|
||||||
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
|
#print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
|
||||||
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
|
#print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
|
||||||
#save(b, i, ib, 'speech_audio')
|
save(b, i, ib, 'speech_audio')
|
||||||
#save(b, i, ib, 'speech_audio_conditioning', 0)
|
save(b, i, ib, 'speech_audio_conditioning', 0)
|
||||||
#save(b, i, ib, 'speech_audio_conditioning', 1)
|
save(b, i, ib, 'speech_audio_conditioning', 1)
|
||||||
#print(f'Text: {b["text_text"][ib]}')
|
#print(f'Text: {b["text_text"][ib]}')
|
||||||
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
|
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
|
||||||
if i > 5:
|
if i > 5:
|
||||||
|
|
|
@ -51,19 +51,21 @@ def load_audio(audiopath, sampling_rate):
|
||||||
return audio.unsqueeze(0)
|
return audio.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
|
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True, fallback_to_self=True):
|
||||||
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
|
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
|
||||||
candidates = []
|
candidates = []
|
||||||
if os.path.exists(sim_path):
|
if os.path.exists(sim_path):
|
||||||
similarities = torch.load(sim_path)
|
similarities = torch.load(sim_path)
|
||||||
fname = os.path.basename(path)
|
fname = os.path.basename(path)
|
||||||
if fname in similarities.keys():
|
if fname in similarities.keys():
|
||||||
candidates = similarities[fname]
|
candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
|
||||||
else:
|
else:
|
||||||
print(f'Similarities list found for {path} but {fname} was not in that list.')
|
print(f'Similarities list found for {path} but {fname} was not in that list.')
|
||||||
if len(candidates) == 0:
|
if len(candidates) == 0:
|
||||||
print(f"Falling back to non-similarity list for {path}")
|
if fallback_to_self:
|
||||||
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
candidates = [path]
|
||||||
|
else:
|
||||||
|
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||||
|
|
||||||
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||||
if not include_self:
|
if not include_self:
|
||||||
|
@ -75,7 +77,7 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True)
|
||||||
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
|
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
|
||||||
related_clips = []
|
related_clips = []
|
||||||
for k in range(n):
|
for k in range(n):
|
||||||
rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
|
rel_clip = load_audio(random.choice(candidates), sample_rate)
|
||||||
gap = rel_clip.shape[-1] - sample_length
|
gap = rel_clip.shape[-1] - sample_length
|
||||||
if gap < 0:
|
if gap < 0:
|
||||||
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
||||||
|
|
|
@ -45,6 +45,8 @@ def process_subdir(subdir, options, clip_sz):
|
||||||
clip_model = load_model_from_config(preloaded_options=options, model_name='clip', also_load_savepoint=True)
|
clip_model = load_model_from_config(preloaded_options=options, model_name='clip', also_load_savepoint=True)
|
||||||
|
|
||||||
root, paths = subdir
|
root, paths = subdir
|
||||||
|
if len(paths) == 0:
|
||||||
|
return
|
||||||
root = str(root)
|
root = str(root)
|
||||||
output_file = os.path.join(root, 'similarities.pth')
|
output_file = os.path.join(root, 'similarities.pth')
|
||||||
if os.path.exists(output_file):
|
if os.path.exists(output_file):
|
||||||
|
@ -100,8 +102,8 @@ if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml')
|
parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml')
|
||||||
parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=1)
|
parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=6)
|
||||||
parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\clips\\podcasts-0\\5177_20190625-Food Waste is Solvable')
|
parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\bigasr_dataset\\tedlium')
|
||||||
parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050)
|
parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user