fix vall_e.emb.process

This commit is contained in:
mrq 2024-10-08 20:00:34 -05:00
parent 0656a762af
commit 52299127ab
2 changed files with 17 additions and 7 deletions

View File

@ -25,19 +25,22 @@ from .qnt import encode as quantize
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
def load_audio( path ):
def load_audio( path, device=None ):
waveform, sr = torchaudio.load( path )
if waveform.shape[0] > 1:
# mix channels
waveform = torch.mean(waveform, dim=0, keepdim=True)
if device is not None:
waveform = waveform.to(device)
return waveform, sr
def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
# encodec requires this to be on CPU for resampling
qnt = quantize(waveform, sr=sample_rate, device=device)
if cfg.audio_backend == "dac":
state_dict = {
@ -72,14 +75,14 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
np.save(open(outpath, "wb"), state_dict)
def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
if not jobs:
return
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
outpath, waveform, sample_rate, text, language = job
try:
process_job( outpath, waveform, sample_rate, text, language )
process_job( outpath, waveform, sample_rate, text, language, device )
except Exception as e:
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
if raise_exceptions:
@ -104,6 +107,7 @@ def process(
amp=False,
):
# prepare from args
cfg.device = device
cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension
@ -173,12 +177,14 @@ def process(
metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
if not metadata_path.exists():
missing["transcription"].append(str(metadata_path))
_logger.warning(f'Missing transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json')
continue
try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e:
missing["transcription"].append(str(metadata_path))
_logger.warning(f'Failed to open transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json: {e}')
continue
if f'{group_name}/{speaker_id}' not in dataset:
@ -243,12 +249,12 @@ def process(
# processes audio files one at a time
if low_memory:
process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
jobs = []
# processes all audio files for a given speaker
if not low_memory:
process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
jobs = []
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))

View File

@ -247,6 +247,7 @@ def _load_encodec_model(device="cuda", levels=0):
model.sample_rate = cfg.sample_rate
model.normalize = cfg.inference.normalize
model.backend = "encodec"
model.device = device
return model
@ -274,6 +275,7 @@ def _load_vocos_model(device="cuda", levels=0):
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
model.sample_rate = cfg.sample_rate
model.backend = "vocos"
model.device = device
return model
@ -294,6 +296,7 @@ def _load_dac_model(device="cuda"):
model.backend = "dac"
model.model_type = kwargs["model_type"]
model.device = device
return model
@ -309,6 +312,7 @@ def _load_audiodec_model(device="cuda", model_name=None):
model.backend = "audiodec"
model.sample_rate = sample_rate
model.device = device
return model