could have sworn i had 'vall_e.emb.process --dtype' working, also possible RAM optimization so I can stop locking up my server when firing four encoding processes
This commit is contained in:
parent
47eb498046
commit
ed94b261dc
|
@ -25,11 +25,13 @@ from .qnt import encode as quantize, encode_batch as quantize_batch
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
def load_audio( path, device=None ):
|
def load_audio( path, device=None, dtype=None ):
|
||||||
waveform, sr = torchaudio.load( path )
|
waveform, sr = torchaudio.load( path )
|
||||||
if waveform.shape[0] > 1:
|
if waveform.shape[0] > 1:
|
||||||
# mix channels
|
# mix channels
|
||||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
if dtype is not None:
|
||||||
|
waveform = waveform.to(dtype)
|
||||||
if device is not None:
|
if device is not None:
|
||||||
waveform = waveform.to(device)
|
waveform = waveform.to(device)
|
||||||
return waveform, sr
|
return waveform, sr
|
||||||
|
@ -190,7 +192,7 @@ def process(
|
||||||
cfg.inference.weight_dtype = dtype # "bfloat16"
|
cfg.inference.weight_dtype = dtype # "bfloat16"
|
||||||
cfg.inference.amp = amp # False
|
cfg.inference.amp = amp # False
|
||||||
|
|
||||||
dtype = None
|
dtype = cfg.inference.dtype if not amp else None
|
||||||
|
|
||||||
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
||||||
|
|
||||||
|
@ -245,7 +247,7 @@ def process(
|
||||||
if outpath.exists():
|
if outpath.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waveform, sample_rate = load_audio( inpath )
|
waveform, sample_rate = load_audio( inpath, dtype=dtype )
|
||||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||||
|
|
||||||
process_job(outpath, waveform, sample_rate)
|
process_job(outpath, waveform, sample_rate)
|
||||||
|
@ -296,7 +298,7 @@ def process(
|
||||||
|
|
||||||
# audio not already loaded, load it
|
# audio not already loaded, load it
|
||||||
if waveform is None:
|
if waveform is None:
|
||||||
waveform, sample_rate = load_audio( inpath )
|
waveform, sample_rate = load_audio( inpath, dtype=dtype )
|
||||||
|
|
||||||
if max_duration and waveform.shape[-1] / sample_rate > max_duration:
|
if max_duration and waveform.shape[-1] / sample_rate > max_duration:
|
||||||
continue
|
continue
|
||||||
|
@ -325,7 +327,7 @@ def process(
|
||||||
|
|
||||||
# audio not already loaded, load it
|
# audio not already loaded, load it
|
||||||
if waveform is None:
|
if waveform is None:
|
||||||
waveform, sample_rate = load_audio( inpath )
|
waveform, sample_rate = load_audio( inpath, dtype=dtype )
|
||||||
|
|
||||||
start = int((segment['start']-0.05) * sample_rate)
|
start = int((segment['start']-0.05) * sample_rate)
|
||||||
end = int((segment['end']+0.5) * sample_rate)
|
end = int((segment['end']+0.5) * sample_rate)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user