ugh
This commit is contained in:
parent
eac353cd0b
commit
613024ec0d
|
@ -22,19 +22,19 @@ from .qnt import encode as quantize
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
def load_audio( path, device ):
|
def load_audio( path ):
|
||||||
waveform, sr = torchaudio.load( path )
|
waveform, sr = torchaudio.load( path )
|
||||||
if waveform.shape[0] > 1:
|
if waveform.shape[0] > 1:
|
||||||
# mix channels
|
# mix channels
|
||||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
return waveform.to(device=device), sr
|
return waveform, sr
|
||||||
|
|
||||||
def process_items( items, stride=0, stride_offset=0 ):
|
def process_items( items, stride=0, stride_offset=0 ):
|
||||||
items = sorted( items )
|
items = sorted( items )
|
||||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||||
|
|
||||||
def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
def process_job( outpath, waveform, sample_rate, text=None, language="en" ):
|
||||||
qnt = quantize(waveform, sr=sample_rate, device=waveform.device)
|
qnt = quantize(waveform.to(device=cfg.device), sr=sample_rate, device=cfg.device)
|
||||||
|
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
state_dict = {
|
state_dict = {
|
||||||
|
@ -156,7 +156,7 @@ def process(
|
||||||
if outpath.exists():
|
if outpath.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waveform, sample_rate = load_audio( inpath, device )
|
waveform, sample_rate = load_audio( inpath )
|
||||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||||
|
|
||||||
process_job(outpath, waveform, sample_rate)
|
process_job(outpath, waveform, sample_rate)
|
||||||
|
@ -202,7 +202,7 @@ def process(
|
||||||
|
|
||||||
# audio not already loaded, load it
|
# audio not already loaded, load it
|
||||||
if waveform is None:
|
if waveform is None:
|
||||||
waveform, sample_rate = load_audio( inpath, device )
|
waveform, sample_rate = load_audio( inpath )
|
||||||
|
|
||||||
jobs.append(( outpath, waveform, sample_rate, text, language ))
|
jobs.append(( outpath, waveform, sample_rate, text, language ))
|
||||||
else:
|
else:
|
||||||
|
@ -219,7 +219,7 @@ def process(
|
||||||
|
|
||||||
# audio not already loaded, load it
|
# audio not already loaded, load it
|
||||||
if waveform is None:
|
if waveform is None:
|
||||||
waveform, sample_rate = load_audio( inpath, device )
|
waveform, sample_rate = load_audio( inpath )
|
||||||
|
|
||||||
start = int(segment['start'] * sample_rate)
|
start = int(segment['start'] * sample_rate)
|
||||||
end = int(segment['end'] * sample_rate)
|
end = int(segment['end'] * sample_rate)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user