Move gen_similarities and rename
This commit is contained in:
parent
8b376e63d9
commit
735f6e4640
|
@ -38,68 +38,71 @@ def process_subdir(subdir, options, clip_sz):
|
||||||
global clip_model
|
global clip_model
|
||||||
if clip_model is None:
|
if clip_model is None:
|
||||||
print('Loading CLIP model..')
|
print('Loading CLIP model..')
|
||||||
clip_model = load_model_from_config(preloaded_options=options, model_name='clip', also_load_savepoint=True)
|
clip_model = load_model_from_config(preloaded_options=options, model_name='clip', also_load_savepoint=True).cuda()
|
||||||
|
clip_model.eval()
|
||||||
|
|
||||||
root, paths = subdir
|
with torch.no_grad():
|
||||||
if len(paths) == 0:
|
root, paths = subdir
|
||||||
return
|
if len(paths) == 0:
|
||||||
root = str(root)
|
return
|
||||||
output_file = os.path.join(root, 'similarities.pth')
|
root = str(root)
|
||||||
if os.path.exists(output_file):
|
output_file = os.path.join(root, 'similarities.pth')
|
||||||
print(f'{root} already processed. Skipping.')
|
if os.path.exists(output_file):
|
||||||
return
|
print(f'{root} already processed. Skipping.')
|
||||||
print(f'Processing {root}..')
|
return
|
||||||
|
print(f'Processing {root}..')
|
||||||
|
|
||||||
clips = []
|
clips = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
clip = load_audio(str(path), 22050)
|
clip = load_audio(str(path), 22050)
|
||||||
padding = clip_sz - clip.shape[1]
|
padding = clip_sz - clip.shape[1]
|
||||||
if padding > 0:
|
if padding > 0:
|
||||||
clip = F.pad(clip, (0, padding))
|
clip = F.pad(clip, (0, padding))
|
||||||
elif padding < 0:
|
elif padding < 0:
|
||||||
clip = clip[:, :clip_sz]
|
clip = clip[:, :clip_sz]
|
||||||
clips.append(clip)
|
clips.append(clip)
|
||||||
sims = None
|
sims = None
|
||||||
while len(clips) > 0:
|
while len(clips) > 0:
|
||||||
stacked = torch.stack(clips[:256], dim=0).cuda()
|
stacked = torch.stack(clips[:256], dim=0).cuda()
|
||||||
clips = clips[256:]
|
clips = clips[256:]
|
||||||
mels = wav_to_mel(stacked)
|
mels = wav_to_mel(stacked).cuda()
|
||||||
outp = clip_model.inference(mels).cpu()
|
outp = clip_model.inference(mels).cpu()
|
||||||
if sims is None:
|
if sims is None:
|
||||||
sims = outp
|
sims = outp
|
||||||
else:
|
else:
|
||||||
if outp.shape[-1] != 256:
|
if outp.shape[-1] != 256:
|
||||||
outp = F.pad(outp, (0,256-outp.shape[-1]))
|
outp = F.pad(outp, (0,256-outp.shape[-1]))
|
||||||
sims = torch.cat([sims, outp], dim=0)
|
sims = torch.cat([sims, outp], dim=0)
|
||||||
|
|
||||||
simmap = {}
|
simmap = {}
|
||||||
# TODO: this can be further improved. We're just taking the topk here but, there is no gaurantee that there is 3
|
# TODO: this can be further improved. We're just taking the topk here but, there is no gaurantee that there is 3
|
||||||
# samples from the same speaker in any given folder.
|
# samples from the same speaker in any given folder.
|
||||||
for path, sim in zip(paths, sims):
|
for path, sim in zip(paths, sims):
|
||||||
n = min(4, len(sim))
|
n = min(4, len(sim))
|
||||||
top3 = torch.topk(sim, n)
|
top3 = torch.topk(sim, n)
|
||||||
rel = os.path.relpath(str(path), root)
|
rel = os.path.relpath(str(path), root)
|
||||||
simpaths = []
|
simpaths = []
|
||||||
if n == 1:
|
if n == 1:
|
||||||
simpaths.append(rel)
|
simpaths.append(rel)
|
||||||
else:
|
else:
|
||||||
for i in range(1,n): # The first entry is always the file itself.
|
for i in range(1,n): # The first entry is always the file itself.
|
||||||
top_ind = top3.indices[i]
|
top_ind = top3.indices[i]
|
||||||
simpaths.append(os.path.relpath(paths[top_ind], root))
|
simpaths.append(os.path.relpath(paths[top_ind], root))
|
||||||
simmap[rel] = simpaths
|
simmap[rel] = simpaths
|
||||||
torch.save(simmap, output_file)
|
torch.save(simmap, output_file)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
This script iterates within a directory filled with subdirs. Each subdir contains a list of audio files from the same
|
This script iterates within a directory filled with subdirs. Each subdir contains a list of audio files from the same
|
||||||
source. The script uses an speech-to-speech clip model to find the <n> most similar audio clips within each subdir for
|
source. The script uses an speech-to-speech clip model to find the <n> most similar audio clips within each subdir for
|
||||||
each clip within that subdir.
|
each clip within that subdir. These similar files are recorded in a "similarities.pth" file in each subdirectory, which
|
||||||
|
is consumed during training when the dataset searches for conditioning clips.
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml')
|
parser.add_argument('-o', type=str, help='Path to the options YAML file used to train the CLIP model', default='../options/train_voice_voice_clip.yml')
|
||||||
parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=6)
|
parser.add_argument('--num_workers', type=int, help='Number concurrent processes to use', default=6)
|
||||||
parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\bigasr_dataset\\tedlium')
|
parser.add_argument('--root_path', type=str, help='Root path to search for audio directories from', default='Y:\\filtered\\big_podcast')
|
||||||
parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050)
|
parser.add_argument('--clip_size', type=int, help='Amount of audio samples to pull from each file', default=22050)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user