final tweaks, hopefully
This commit is contained in:
parent
ffc334cf58
commit
caad7ee3c9
99
scripts/cleanup_dataset.py
Normal file
99
scripts/cleanup_dataset.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
input_dataset = "metadata"
|
||||
output_dataset = "metadata-cleaned"
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
for dataset_name in os.listdir(f'./{input_dataset}/'):
|
||||
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
|
||||
print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
|
||||
continue
|
||||
|
||||
for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
|
||||
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
|
||||
print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
|
||||
continue
|
||||
|
||||
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||
|
||||
if not inpath.exists():
|
||||
continue
|
||||
|
||||
if outpath.exists():
|
||||
continue
|
||||
|
||||
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
|
||||
|
||||
try:
|
||||
in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
|
||||
except Exception as e:
|
||||
print("Failed to open metadata file:", inpath)
|
||||
continue
|
||||
|
||||
out_metadata = {}
|
||||
speaker_metadatas = {}
|
||||
|
||||
for filename, result in in_metadata.items():
|
||||
language = result["language"] if "language" in result else "en"
|
||||
out_metadata[filename] = {
|
||||
"segments": [],
|
||||
"language": language,
|
||||
"text": "",
|
||||
"start": 0,
|
||||
"end": 0,
|
||||
}
|
||||
segments = []
|
||||
text = []
|
||||
start = 0
|
||||
end = 0
|
||||
diarized = False
|
||||
|
||||
for segment in result["segments"]:
|
||||
# diarize split
|
||||
if "speaker" in segment:
|
||||
diarized = True
|
||||
speaker_id = segment["speaker"]
|
||||
if speaker_id not in speaker_metadatas:
|
||||
speaker_metadatas[speaker_id] = {}
|
||||
|
||||
if filename not in speaker_metadatas[speaker_id]:
|
||||
speaker_metadatas[speaker_id][filename] = {
|
||||
"segments": [],
|
||||
"language": language,
|
||||
"text": "",
|
||||
"start": 0,
|
||||
"end": 0,
|
||||
}
|
||||
|
||||
speaker_metadatas[speaker_id][filename]["segments"].append( segment )
|
||||
else:
|
||||
segments.append( segment )
|
||||
|
||||
text.append( segment["text"] )
|
||||
start = min( start, segment["start"] )
|
||||
end = max( end, segment["end"] )
|
||||
|
||||
out_metadata[filename]["segments"] = segments
|
||||
out_metadata[filename]["text"] = " ".join(text).strip()
|
||||
out_metadata[filename]["start"] = start
|
||||
out_metadata[filename]["end"] = end
|
||||
|
||||
if len(segments) == 0:
|
||||
del out_metadata[filename]
|
||||
|
||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
|
||||
|
||||
for speaker_id, out_metadata in speaker_metadatas.items():
|
||||
os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||
|
||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
|
|
@ -8,26 +8,27 @@ from pathlib import Path
|
|||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
||||
|
||||
input_audio = "voice"
|
||||
# things that could be args
|
||||
input_audio = "voices"
|
||||
input_metadata = "metadata"
|
||||
output_dataset = "training-24K"
|
||||
device = "cuda"
|
||||
|
||||
slice = "auto"
|
||||
missing = {
|
||||
"transcription": [],
|
||||
"audio": []
|
||||
}
|
||||
|
||||
device = "cuda"
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||
for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
|
||||
continue
|
||||
|
||||
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
|
||||
for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
|
||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
|
||||
continue
|
||||
|
@ -36,24 +37,23 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
|||
|
||||
metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
|
||||
if not metadata_path.exists():
|
||||
#print("Does not exist:", metadata_path)
|
||||
missing["transcription"].append(str(metadata_path))
|
||||
continue
|
||||
|
||||
try:
|
||||
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
||||
except Exception as e:
|
||||
#print("Failed to load metadata:", metadata_path, e)
|
||||
missing["transcription"].append(str(metadata_path))
|
||||
continue
|
||||
|
||||
txts = []
|
||||
wavs = []
|
||||
|
||||
for filename in metadata.keys():
|
||||
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
|
||||
|
||||
for filename in sorted(metadata.keys()):
|
||||
inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
|
||||
if not inpath.exists():
|
||||
#print("Does not exist:", inpath)
|
||||
missing["audio"].append(str(inpath))
|
||||
continue
|
||||
|
||||
|
@ -63,9 +63,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
|||
waveform, sample_rate = None, None
|
||||
language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
|
||||
|
||||
if len(metadata[filename]["segments"]) == 0:
|
||||
id = pad(0, 4)
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
|
||||
if len(metadata[filename]["segments"]) == 0 or not use_slices:
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
|
||||
text = metadata[filename]["text"]
|
||||
|
||||
if len(text) == 0:
|
||||
|
@ -91,8 +90,10 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
|||
sample_rate
|
||||
))
|
||||
else:
|
||||
i = 0
|
||||
for segment in metadata[filename]["segments"]:
|
||||
id = pad(segment['id'], 4)
|
||||
id = pad(i, 4)
|
||||
i = i + 1
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
|
||||
|
||||
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
|
||||
|
|
|
@ -7,30 +7,35 @@ import whisperx
|
|||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
device = "cuda"
|
||||
# should be args
|
||||
batch_size = 16
|
||||
device = "cuda"
|
||||
dtype = "float16"
|
||||
model_size = "large-v2"
|
||||
model_name = "large-v3"
|
||||
|
||||
input_audio = "voice"
|
||||
input_audio = "voices"
|
||||
output_dataset = "metadata"
|
||||
|
||||
skip_existing = True
|
||||
diarize = False
|
||||
|
||||
model = whisperx.load_model(model_size, device, compute_type=dtype)
|
||||
|
||||
#
|
||||
model = whisperx.load_model(model_name, device, compute_type=dtype)
|
||||
align_model, align_model_metadata, align_model_language = (None, None, None)
|
||||
if diarize:
|
||||
diarize_model = whisperx.DiarizationPipeline(device=device)
|
||||
else:
|
||||
diarize_model = None
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/')
|
||||
continue
|
||||
|
||||
for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
|
||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
|
||||
print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
|
||||
continue
|
||||
|
||||
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
|
||||
|
@ -46,18 +51,29 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
|||
if skip_existing and filename in metadata:
|
||||
continue
|
||||
|
||||
if ".json" in filename:
|
||||
continue
|
||||
|
||||
inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
|
||||
|
||||
if os.path.isdir(inpath):
|
||||
continue
|
||||
|
||||
metadata[filename] = {
|
||||
"segments": [],
|
||||
"language": "",
|
||||
"text": [],
|
||||
"text": "",
|
||||
"start": 0,
|
||||
"end": 0,
|
||||
}
|
||||
|
||||
audio = whisperx.load_audio(inpath)
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
language = result["language"]
|
||||
|
||||
if language[:2] not in ["ja"]:
|
||||
language = "en"
|
||||
|
||||
if align_model_language != language:
|
||||
tqdm.write(f'Loading language: {language}')
|
||||
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
|
||||
|
@ -68,12 +84,20 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
|
|||
metadata[filename]["segments"] = result["segments"]
|
||||
metadata[filename]["language"] = language
|
||||
|
||||
if diarize_model is not None:
|
||||
diarize_segments = diarize_model(audio)
|
||||
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||
|
||||
text = []
|
||||
start = 0
|
||||
end = 0
|
||||
for segment in result["segments"]:
|
||||
id = len(text)
|
||||
text.append( segment["text"] )
|
||||
metadata[filename]["segments"][id]["id"] = id
|
||||
start = min( start, segment["start"] )
|
||||
end = max( end, segment["end"] )
|
||||
|
||||
metadata[filename]["text"] = " ".join(text).strip()
|
||||
metadata[filename]["start"] = start
|
||||
metadata[filename]["end"] = end
|
||||
|
||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
|
@ -33,6 +33,14 @@ class _Config:
|
|||
def cache_dir(self):
|
||||
return self.relpath / ".cache"
|
||||
|
||||
@property
|
||||
def data_dir(self):
|
||||
return self.relpath / "data"
|
||||
|
||||
@property
|
||||
def metadata_dir(self):
|
||||
return self.relpath / "metadata"
|
||||
|
||||
@property
|
||||
def ckpt_dir(self):
|
||||
return self.relpath / "ckpt"
|
||||
|
|
117
vall_e/data.py
117
vall_e/data.py
|
@ -85,46 +85,45 @@ def _calculate_durations( type="training" ):
|
|||
|
||||
@cfg.diskcache()
|
||||
def _load_paths(dataset, type="training"):
|
||||
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
|
||||
def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
||||
data_dir = dataset_name if cfg.dataset.use_hdf5 else cfg.data_dir / dataset_name
|
||||
|
||||
def _load_paths_from_metadata(data_dir, type="training", validate=False):
|
||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||
|
||||
def _validate( entry ):
|
||||
if "phones" not in entry or "duration" not in entry:
|
||||
return False
|
||||
phones = entry['phones']
|
||||
duration = entry['duration']
|
||||
phones = entry['phones'] if "phones" in entry else 0
|
||||
duration = entry['duration'] if "duration" in entry else 0
|
||||
if type not in _total_durations:
|
||||
_total_durations[type] = 0
|
||||
_total_durations[type] += entry['duration']
|
||||
_total_durations[type] += duration
|
||||
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
metadata_path = data_dir / "metadata.json"
|
||||
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
|
||||
metadata = {}
|
||||
|
||||
if cfg.dataset.use_metadata and metadata_path.exists():
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||
|
||||
if len(metadata) == 0:
|
||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||
|
||||
|
||||
def key( dir, id ):
|
||||
if not cfg.dataset.use_hdf5:
|
||||
return data_dir / id
|
||||
|
||||
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
|
||||
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
||||
|
||||
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
||||
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
path = str(path)
|
||||
if path[:2] != "./":
|
||||
path = f'./{path}'
|
||||
|
||||
res = path.replace(cfg.cfg_path, "")
|
||||
return res
|
||||
# to-do: better validation
|
||||
#print(path)
|
||||
return str(path)
|
||||
|
||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
@ -137,7 +136,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
|||
_total_durations[type] += child.attrs['duration']
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
||||
|
@ -427,6 +426,9 @@ class Dataset(_Dataset):
|
|||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
||||
if key not in cfg.hdf5:
|
||||
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||
|
||||
text = cfg.hdf5[key]["text"][:]
|
||||
resps = cfg.hdf5[key]["audio"][:, :]
|
||||
|
||||
|
@ -752,6 +754,10 @@ def create_train_val_dataloader():
|
|||
|
||||
# parse dataset into better to sample metadata
|
||||
def create_dataset_metadata():
|
||||
# need to fix
|
||||
if True:
|
||||
return
|
||||
|
||||
cfg.dataset.validate = False
|
||||
cfg.dataset.use_hdf5 = False
|
||||
|
||||
|
@ -805,14 +811,19 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
symmap = get_phone_symmap()
|
||||
|
||||
root = cfg.cfg_path
|
||||
root = str(cfg.data_dir)
|
||||
metadata_root = str(cfg.metadata_dir)
|
||||
hf = cfg.hdf5
|
||||
|
||||
cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
name = "./" + str(dir)
|
||||
name = name .replace(root, "")
|
||||
metadata = {}
|
||||
name = str(dir)
|
||||
name = name.replace(root, "")
|
||||
|
||||
metadata_path = Path(f"{metadata_root}/{name}.json")
|
||||
|
||||
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
|
||||
|
||||
if not os.path.isdir(f'{root}/{name}/'):
|
||||
return
|
||||
|
@ -831,36 +842,38 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
continue
|
||||
|
||||
key = f'{type}/{name}/{id}'
|
||||
if key in hf:
|
||||
if skip_existing:
|
||||
continue
|
||||
del hf[key]
|
||||
|
||||
group = hf.create_group(key)
|
||||
if skip_existing and key in hf:
|
||||
continue
|
||||
|
||||
group = hf.create_group(key) if key not in hf else hf[key]
|
||||
|
||||
group.attrs['id'] = id
|
||||
group.attrs['type'] = type
|
||||
group.attrs['speaker'] = name
|
||||
|
||||
metadata[id] = {}
|
||||
if id not in metadata:
|
||||
metadata[id] = {}
|
||||
|
||||
# audio
|
||||
if audios:
|
||||
qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
if _get_quant_extension() == ".dac":
|
||||
if "audio" in group:
|
||||
del group["audio"]
|
||||
duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
|
||||
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
||||
metadata[id]["metadata"] = {
|
||||
"original_length": qnt["metadata"]["original_length"],
|
||||
"sample_rate": qnt["metadata"]["sample_rate"],
|
||||
"original_length": dac["metadata"]["original_length"],
|
||||
"sample_rate": dac["metadata"]["sample_rate"],
|
||||
}
|
||||
else:
|
||||
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
|
||||
duration = qnt.shape[0] / 75
|
||||
|
||||
group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
|
||||
qnt = qnt.numpy().astype(np.int16)
|
||||
|
||||
if "audio" not in group:
|
||||
group.create_dataset('audio', data=qnt, compression='lzf')
|
||||
|
||||
group.attrs['duration'] = duration
|
||||
metadata[id]["duration"] = duration
|
||||
|
@ -870,52 +883,46 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
# text
|
||||
if texts:
|
||||
if _get_quant_extension() == ".json":
|
||||
if _get_phone_extension() == ".json":
|
||||
json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||
content = json_metadata["phonemes"]
|
||||
txt = json_metadata["text"]
|
||||
else:
|
||||
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
|
||||
|
||||
"""
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
|
||||
for s in set(phones):
|
||||
if s not in symmap:
|
||||
symmap[s] = len(symmap.keys())
|
||||
|
||||
phn = [ symmap[s] for s in phones ]
|
||||
"""
|
||||
txt = ""
|
||||
|
||||
phn = cfg.tokenizer.encode("".join(content))
|
||||
phn = np.array(phn).astype(np.uint8)
|
||||
|
||||
if "text" in group:
|
||||
del group["text"]
|
||||
|
||||
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
|
||||
group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
|
||||
if "text" not in group:
|
||||
group.create_dataset('text', data=phn, compression='lzf')
|
||||
|
||||
group.attrs['phonemes'] = len(phn)
|
||||
group.attrs['transcription'] = txt
|
||||
|
||||
metadata[id]["phones"] = len(phn)
|
||||
metadata[id]["transcription"] = txt
|
||||
else:
|
||||
group.attrs['phonemes'] = 0
|
||||
metadata[id]["phones"] = 0
|
||||
except Exception as e:
|
||||
pass
|
||||
raise e
|
||||
#pass
|
||||
|
||||
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
|
||||
add( data_dir, type="training" )
|
||||
|
||||
# validation
|
||||
for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
|
||||
for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
|
||||
add( data_dir, type="validation" )
|
||||
|
||||
# noise
|
||||
for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
|
||||
for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
|
||||
add( data_dir, type="noise", texts=False )
|
||||
|
||||
# write symmap
|
||||
|
|
|
@ -340,7 +340,7 @@ def example_usage():
|
|||
def _load_quants(path) -> Tensor:
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
|
||||
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
|
||||
return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
|
||||
|
||||
qnt = _load_quants("./data/qnt")
|
||||
|
@ -350,7 +350,7 @@ def example_usage():
|
|||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt[:75*3, :].to(device),
|
||||
qnt.to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
|
@ -407,7 +407,7 @@ def example_usage():
|
|||
|
||||
frozen_params = set()
|
||||
for k in list(embeddings.keys()):
|
||||
if re.findall(r'_emb\.', k):
|
||||
if re.findall(r'_emb.', k):
|
||||
frozen_params.add(k)
|
||||
else:
|
||||
del embeddings[k]
|
||||
|
|
Loading…
Reference in New Issue
Block a user