fix vall_e.emb.process
This commit is contained in:
parent
0656a762af
commit
52299127ab
@ -25,19 +25,22 @@ from .qnt import encode as quantize
|
||||
def pad(num, zeroes):
	"""Return ``str(num)`` left-padded with ``0`` characters.

	The target width is ``zeroes + 1`` characters; values whose string form
	is already that long or longer are returned unchanged (``str.zfill``
	never truncates).
	"""
	return str(num).zfill(zeroes+1)
|
||||
|
||||
def load_audio( path ):
|
||||
def load_audio( path, device=None ):
	"""Load an audio file and return ``(waveform, sr)``.

	``path`` is handed to ``torchaudio.load``. When the loaded tensor has more
	than one entry along dim 0 it is averaged over that dim (kept as a size-1
	dim) -- presumably dim 0 is the channel axis, so this downmixes to mono.
	If ``device`` is not ``None`` the waveform is moved there before returning.
	"""
	waveform, sr = torchaudio.load( path )
	if waveform.shape[0] > 1:
		# mix channels
		waveform = torch.mean(waveform, dim=0, keepdim=True)
	if device is not None:
		waveform = waveform.to(device)
	return waveform, sr
|
||||
|
||||
def process_items( items, stride=0, stride_offset=0 ):
	"""Return ``items`` sorted, optionally thinned to every ``stride``-th entry.

	With ``stride == 0`` the full sorted list is returned. Otherwise only the
	entries whose sorted index ``i`` satisfies
	``(i + stride_offset) % stride == 0`` are kept.
	"""
	items = sorted( items )
	# a stride of 0 means "no striding"; it also avoids a modulo-by-zero below
	if stride == 0:
		return items
	return [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||
|
||||
def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
||||
qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)
|
||||
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
|
||||
# encodec requires this to be on CPU for resampling
|
||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.audio_backend == "dac":
|
||||
state_dict = {
|
||||
@ -72,14 +75,14 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
||||
|
||||
np.save(open(outpath, "wb"), state_dict)
|
||||
|
||||
def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
|
||||
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
|
||||
if not jobs:
|
||||
return
|
||||
|
||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||
outpath, waveform, sample_rate, text, language = job
|
||||
try:
|
||||
process_job( outpath, waveform, sample_rate, text, language )
|
||||
process_job( outpath, waveform, sample_rate, text, language, device )
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||
if raise_exceptions:
|
||||
@ -104,6 +107,7 @@ def process(
|
||||
amp=False,
|
||||
):
|
||||
# prepare from args
|
||||
cfg.device = device
|
||||
cfg.set_audio_backend(audio_backend)
|
||||
audio_extension = cfg.audio_backend_extension
|
||||
|
||||
@ -173,12 +177,14 @@ def process(
|
||||
metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
|
||||
if not metadata_path.exists():
|
||||
missing["transcription"].append(str(metadata_path))
|
||||
_logger.warning(f'Missing transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json')
|
||||
continue
|
||||
|
||||
try:
|
||||
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
||||
except Exception as e:
|
||||
missing["transcription"].append(str(metadata_path))
|
||||
_logger.warning(f'Failed to open transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json: {e}')
|
||||
continue
|
||||
|
||||
if f'{group_name}/{speaker_id}' not in dataset:
|
||||
@ -243,12 +249,12 @@ def process(
|
||||
|
||||
# processes audio files one at a time
|
||||
if low_memory:
|
||||
process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
||||
jobs = []
|
||||
|
||||
# processes all audio files for a given speaker
|
||||
if not low_memory:
|
||||
process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
||||
jobs = []
|
||||
|
||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||
|
||||
@ -247,6 +247,7 @@ def _load_encodec_model(device="cuda", levels=0):
|
||||
model.sample_rate = cfg.sample_rate
|
||||
model.normalize = cfg.inference.normalize
|
||||
model.backend = "encodec"
|
||||
model.device = device
|
||||
|
||||
return model
|
||||
|
||||
@ -274,6 +275,7 @@ def _load_vocos_model(device="cuda", levels=0):
|
||||
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
||||
model.sample_rate = cfg.sample_rate
|
||||
model.backend = "vocos"
|
||||
model.device = device
|
||||
|
||||
return model
|
||||
|
||||
@ -294,6 +296,7 @@ def _load_dac_model(device="cuda"):
|
||||
|
||||
model.backend = "dac"
|
||||
model.model_type = kwargs["model_type"]
|
||||
model.device = device
|
||||
|
||||
return model
|
||||
|
||||
@ -309,6 +312,7 @@ def _load_audiodec_model(device="cuda", model_name=None):
|
||||
|
||||
model.backend = "audiodec"
|
||||
model.sample_rate = sample_rate
|
||||
model.device = device
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user