fix vall_e.emb.process
This commit is contained in:
parent
0656a762af
commit
52299127ab
@ -25,19 +25,22 @@ from .qnt import encode as quantize
|
|||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
def load_audio( path ):
|
def load_audio( path, device=None ):
|
||||||
waveform, sr = torchaudio.load( path )
|
waveform, sr = torchaudio.load( path )
|
||||||
if waveform.shape[0] > 1:
|
if waveform.shape[0] > 1:
|
||||||
# mix channels
|
# mix channels
|
||||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
if device is not None:
|
||||||
|
waveform = waveform.to(device)
|
||||||
return waveform, sr
|
return waveform, sr
|
||||||
|
|
||||||
def process_items( items, stride=0, stride_offset=0 ):
|
def process_items( items, stride=0, stride_offset=0 ):
|
||||||
items = sorted( items )
|
items = sorted( items )
|
||||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||||
|
|
||||||
def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
|
||||||
qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)
|
# encodec requires this to be on CPU for resampling
|
||||||
|
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||||
|
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
state_dict = {
|
state_dict = {
|
||||||
@ -72,14 +75,14 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
|||||||
|
|
||||||
np.save(open(outpath, "wb"), state_dict)
|
np.save(open(outpath, "wb"), state_dict)
|
||||||
|
|
||||||
def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
|
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
|
||||||
if not jobs:
|
if not jobs:
|
||||||
return
|
return
|
||||||
|
|
||||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||||
outpath, waveform, sample_rate, text, language = job
|
outpath, waveform, sample_rate, text, language = job
|
||||||
try:
|
try:
|
||||||
process_job( outpath, waveform, sample_rate, text, language )
|
process_job( outpath, waveform, sample_rate, text, language, device )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||||
if raise_exceptions:
|
if raise_exceptions:
|
||||||
@ -104,6 +107,7 @@ def process(
|
|||||||
amp=False,
|
amp=False,
|
||||||
):
|
):
|
||||||
# prepare from args
|
# prepare from args
|
||||||
|
cfg.device = device
|
||||||
cfg.set_audio_backend(audio_backend)
|
cfg.set_audio_backend(audio_backend)
|
||||||
audio_extension = cfg.audio_backend_extension
|
audio_extension = cfg.audio_backend_extension
|
||||||
|
|
||||||
@ -173,12 +177,14 @@ def process(
|
|||||||
metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
|
metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
|
||||||
if not metadata_path.exists():
|
if not metadata_path.exists():
|
||||||
missing["transcription"].append(str(metadata_path))
|
missing["transcription"].append(str(metadata_path))
|
||||||
|
_logger.warning(f'Missing transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
missing["transcription"].append(str(metadata_path))
|
missing["transcription"].append(str(metadata_path))
|
||||||
|
_logger.warning(f'Failed to open transcription metadata: ./{input_audio}/{group_name}/{speaker_id}/whisper.json: {e}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if f'{group_name}/{speaker_id}' not in dataset:
|
if f'{group_name}/{speaker_id}' not in dataset:
|
||||||
@ -243,12 +249,12 @@ def process(
|
|||||||
|
|
||||||
# processes audio files one at a time
|
# processes audio files one at a time
|
||||||
if low_memory:
|
if low_memory:
|
||||||
process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
# processes all audio files for a given speaker
|
# processes all audio files for a given speaker
|
||||||
if not low_memory:
|
if not low_memory:
|
||||||
process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||||
|
|||||||
@ -247,6 +247,7 @@ def _load_encodec_model(device="cuda", levels=0):
|
|||||||
model.sample_rate = cfg.sample_rate
|
model.sample_rate = cfg.sample_rate
|
||||||
model.normalize = cfg.inference.normalize
|
model.normalize = cfg.inference.normalize
|
||||||
model.backend = "encodec"
|
model.backend = "encodec"
|
||||||
|
model.device = device
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -274,6 +275,7 @@ def _load_vocos_model(device="cuda", levels=0):
|
|||||||
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
|
||||||
model.sample_rate = cfg.sample_rate
|
model.sample_rate = cfg.sample_rate
|
||||||
model.backend = "vocos"
|
model.backend = "vocos"
|
||||||
|
model.device = device
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -294,6 +296,7 @@ def _load_dac_model(device="cuda"):
|
|||||||
|
|
||||||
model.backend = "dac"
|
model.backend = "dac"
|
||||||
model.model_type = kwargs["model_type"]
|
model.model_type = kwargs["model_type"]
|
||||||
|
model.device = device
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -309,6 +312,7 @@ def _load_audiodec_model(device="cuda", model_name=None):
|
|||||||
|
|
||||||
model.backend = "audiodec"
|
model.backend = "audiodec"
|
||||||
model.sample_rate = sample_rate
|
model.sample_rate = sample_rate
|
||||||
|
model.device = device
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user