diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py index f1c56b1b..d9ac38f3 100644 --- a/codes/data/audio/paired_voice_audio_dataset.py +++ b/codes/data/audio/paired_voice_audio_dataset.py @@ -22,7 +22,7 @@ def load_tsv(filename): bad_lines = 0 for line in f: components = line.strip().split('\t') - if len(components) < 3: + if len(components) < 2: bad_lines += 1 if bad_lines > 10: print(f'{filename} contains 10+ bad entries. Failing. Sample last entry: {line}') @@ -31,18 +31,28 @@ def load_tsv(filename): return filepaths_and_text +def convert_string_list_to_tensor(strlist): + if strlist.startswith('['): + strlist = strlist[1:] + if strlist.endswith(']'): + strlist = strlist[:-1] + as_ints = [int(s) for s in strlist.split(', ')] + return torch.tensor(as_ints) + + def load_tsv_aligned_codes(filename): with open(filename, encoding='utf-8') as f: - components = [line.strip().split('\t') for line in f] + filepaths_and_text = [] base = os.path.dirname(filename) - def convert_string_list_to_tensor(strlist): - if strlist.startswith('['): - strlist = strlist[1:] - if strlist.endswith(']'): - strlist = strlist[:-1] - as_ints = [int(s) for s in strlist.split(', ')] - return torch.tensor(as_ints) - filepaths_and_text = [[os.path.join(base, f'{component[1]}'), component[0], convert_string_list_to_tensor(component[2])] for component in components] + bad_lines = 0 + for line in f: + components = line.strip().split('\t') + if len(components) < 2: + bad_lines += 1 + if bad_lines > 10: + print(f'{filename} contains 10+ bad entries. Failing. Sample last entry: {line}') + raise ValueError + filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])]) return filepaths_and_text