Re-added AMP (autocast) encoding/decoding for audio; possibly a bad idea: when AMP is requested, the explicit dtype is ignored and AMP is used instead
This commit is contained in:
parent
7592befc53
commit
299cc88821
@ -38,9 +38,9 @@ def process_items( items, stride=0, stride_offset=0 ):
|
|||||||
items = sorted( items )
|
items = sorted( items )
|
||||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||||
|
|
||||||
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
|
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda", dtype=None ):
|
||||||
# encodec requires this to be on CPU for resampling
|
# encodec requires this to be on CPU for resampling
|
||||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
|
||||||
|
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
state_dict = {
|
state_dict = {
|
||||||
@ -75,10 +75,13 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
|
|||||||
|
|
||||||
np.save(open(outpath, "wb"), state_dict)
|
np.save(open(outpath, "wb"), state_dict)
|
||||||
|
|
||||||
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
|
||||||
if not jobs:
|
if not jobs:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# sort to avoid egregious padding
|
||||||
|
jobs = sorted(jobs, key=lambda x: x[1].shape[-1], reverse=True)
|
||||||
|
|
||||||
buffer = []
|
buffer = []
|
||||||
batches = []
|
batches = []
|
||||||
|
|
||||||
@ -88,7 +91,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||||||
batches.append(buffer)
|
batches.append(buffer)
|
||||||
buffer = []
|
buffer = []
|
||||||
|
|
||||||
if len(buffer) >= batch_size:
|
if buffer:
|
||||||
batches.append(buffer)
|
batches.append(buffer)
|
||||||
buffer = []
|
buffer = []
|
||||||
|
|
||||||
@ -101,7 +104,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||||||
srs.append(sample_rate)
|
srs.append(sample_rate)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
codes = quantize_batch(wavs, sr=srs, device=device)
|
codes = quantize_batch(wavs, sr=srs, device=device, dtype=dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||||
if raise_exceptions:
|
if raise_exceptions:
|
||||||
@ -142,18 +145,18 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||||||
|
|
||||||
np.save(open(outpath, "wb"), state_dict)
|
np.save(open(outpath, "wb"), state_dict)
|
||||||
|
|
||||||
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
|
||||||
if not jobs:
|
if not jobs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# batch things
|
# batch things
|
||||||
if batch_size > 1:
|
if batch_size > 1:
|
||||||
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype )
|
||||||
|
|
||||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||||
outpath, waveform, sample_rate, text, language = job
|
outpath, waveform, sample_rate, text, language = job
|
||||||
try:
|
try:
|
||||||
process_job( outpath, waveform, sample_rate, text, language, device )
|
process_job( outpath, waveform, sample_rate, text, language, device, dtype=dtype )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||||
if raise_exceptions:
|
if raise_exceptions:
|
||||||
@ -186,6 +189,8 @@ def process(
|
|||||||
cfg.inference.weight_dtype = dtype # "bfloat16"
|
cfg.inference.weight_dtype = dtype # "bfloat16"
|
||||||
cfg.inference.amp = amp # False
|
cfg.inference.amp = amp # False
|
||||||
|
|
||||||
|
dtype = None
|
||||||
|
|
||||||
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
||||||
|
|
||||||
# to-do: make this also prepared from args
|
# to-do: make this also prepared from args
|
||||||
@ -334,12 +339,12 @@ def process(
|
|||||||
|
|
||||||
# processes audio files one at a time
|
# processes audio files one at a time
|
||||||
if low_memory:
|
if low_memory:
|
||||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
|
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
# processes all audio files for a given speaker
|
# processes all audio files for a given speaker
|
||||||
if not low_memory:
|
if not low_memory:
|
||||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||||
|
|||||||
@ -45,7 +45,7 @@ except Exception as e:
|
|||||||
_logger.warning(str(e))
|
_logger.warning(str(e))
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_encodec_model(device="cuda", levels=0):
|
def _load_encodec_model(device="cuda", dtype=None, levels=0):
|
||||||
assert cfg.sample_rate == 24_000
|
assert cfg.sample_rate == 24_000
|
||||||
|
|
||||||
if not levels:
|
if not levels:
|
||||||
@ -67,6 +67,9 @@ def _load_encodec_model(device="cuda", levels=0):
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
model = model.to(dtype)
|
||||||
|
|
||||||
# extra metadata
|
# extra metadata
|
||||||
model.bandwidth_id = bandwidth_id
|
model.bandwidth_id = bandwidth_id
|
||||||
model.normalize = cfg.inference.normalize
|
model.normalize = cfg.inference.normalize
|
||||||
@ -75,7 +78,7 @@ def _load_encodec_model(device="cuda", levels=0):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_vocos_model(device="cuda", levels=0):
|
def _load_vocos_model(device="cuda", dtype=None, levels=0):
|
||||||
assert cfg.sample_rate == 24_000
|
assert cfg.sample_rate == 24_000
|
||||||
|
|
||||||
if not levels:
|
if not levels:
|
||||||
@ -85,6 +88,9 @@ def _load_vocos_model(device="cuda", levels=0):
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
model = model.to(dtype)
|
||||||
|
|
||||||
# too lazy to un-if ladder this shit
|
# too lazy to un-if ladder this shit
|
||||||
bandwidth_id = 2
|
bandwidth_id = 2
|
||||||
if levels == 2:
|
if levels == 2:
|
||||||
@ -101,7 +107,7 @@ def _load_vocos_model(device="cuda", levels=0):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_dac_model(device="cuda"):
|
def _load_dac_model(device="cuda", dtype=None):
|
||||||
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||||
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||||
if cfg.sample_rate == 44_100:
|
if cfg.sample_rate == 44_100:
|
||||||
@ -115,17 +121,25 @@ def _load_dac_model(device="cuda"):
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
model = model.to(dtype)
|
||||||
|
|
||||||
model.backend = "dac"
|
model.backend = "dac"
|
||||||
model.model_type = kwargs["model_type"]
|
model.model_type = kwargs["model_type"]
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_nemo_model(device="cuda", model_name=None):
|
def _load_nemo_model(device="cuda", dtype=None, model_name=None):
|
||||||
if not model_name:
|
if not model_name:
|
||||||
model_name = "nvidia/audio-codec-44khz"
|
model_name = "nvidia/audio-codec-44khz"
|
||||||
|
|
||||||
model = AudioCodecModel.from_pretrained(model_name).to(device).eval()
|
model = AudioCodecModel.from_pretrained(model_name)
|
||||||
|
model = model.to(device)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
model = model.to(dtype)
|
||||||
|
|
||||||
model.backend = "nemo"
|
model.backend = "nemo"
|
||||||
|
|
||||||
@ -133,20 +147,23 @@ def _load_nemo_model(device="cuda", model_name=None):
|
|||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_model(device="cuda", backend=None):
|
def _load_model(device="cuda", backend=None, dtype=None):
|
||||||
if not backend:
|
if not backend:
|
||||||
backend = cfg.audio_backend
|
backend = cfg.audio_backend
|
||||||
|
|
||||||
if backend == "nemo":
|
if cfg.inference.amp:
|
||||||
return _load_nemo_model(device)
|
dtype = None
|
||||||
if backend == "audiodec":
|
|
||||||
return _load_audiodec_model(device)
|
|
||||||
if backend == "dac":
|
|
||||||
return _load_dac_model(device)
|
|
||||||
if backend == "vocos":
|
|
||||||
return _load_vocos_model(device)
|
|
||||||
|
|
||||||
return _load_encodec_model(device)
|
if backend == "nemo":
|
||||||
|
return _load_nemo_model(device, dtype=dtype)
|
||||||
|
if backend == "audiodec":
|
||||||
|
return _load_audiodec_model(device, dtype=dtype)
|
||||||
|
if backend == "dac":
|
||||||
|
return _load_dac_model(device, dtype=dtype)
|
||||||
|
if backend == "vocos":
|
||||||
|
return _load_vocos_model(device, dtype=dtype)
|
||||||
|
|
||||||
|
return _load_encodec_model(device, dtype=dtype)
|
||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
_load_model.cache_clear()
|
_load_model.cache_clear()
|
||||||
@ -154,7 +171,7 @@ def unload_model():
|
|||||||
|
|
||||||
# to-do: clean up this mess
|
# to-do: clean up this mess
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
||||||
# upcast so it won't whine
|
# upcast so it won't whine
|
||||||
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||||
codes = codes.to(torch.int32)
|
codes = codes.to(torch.int32)
|
||||||
@ -174,7 +191,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
|||||||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
model = _load_model(device)
|
model = _load_model(device, dtype=dtype)
|
||||||
# move to device
|
# move to device
|
||||||
codes = codes.to( device=device )
|
codes = codes.to( device=device )
|
||||||
|
|
||||||
@ -231,7 +248,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
|||||||
return wav, cfg.sample_rate
|
return wav, cfg.sample_rate
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def decode_batch(codes: list[Tensor], device="cuda"):
|
def decode_batch(codes: list[Tensor], device="cuda", dtype=None):
|
||||||
# transpose if needed
|
# transpose if needed
|
||||||
for i, code in enumerate(codes):
|
for i, code in enumerate(codes):
|
||||||
if code.shape[0] < code.shape[1]:
|
if code.shape[0] < code.shape[1]:
|
||||||
@ -254,7 +271,7 @@ def decode_batch(codes: list[Tensor], device="cuda"):
|
|||||||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
model = _load_model(device)
|
model = _load_model(device, dtype=dtype)
|
||||||
# move to device
|
# move to device
|
||||||
codes = codes.to( device=device )
|
codes = codes.to( device=device )
|
||||||
|
|
||||||
@ -280,7 +297,7 @@ def _replace_file_extension(path, suffix):
|
|||||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
|
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, return_metadata=True, window_duration=None):
|
||||||
# expand if 1D
|
# expand if 1D
|
||||||
if wav.dim() < 2:
|
if wav.dim() < 2:
|
||||||
wav = wav.unsqueeze(0)
|
wav = wav.unsqueeze(0)
|
||||||
@ -288,12 +305,16 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||||||
if wav.dim() < 3:
|
if wav.dim() < 3:
|
||||||
wav = wav.unsqueeze(0)
|
wav = wav.unsqueeze(0)
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
wav = wav.to(dtype)
|
||||||
|
|
||||||
# cringe assert
|
# cringe assert
|
||||||
assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()'
|
assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()'
|
||||||
|
|
||||||
|
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
|
||||||
|
|
||||||
# DAC uses a different pathway
|
# DAC uses a different pathway
|
||||||
if cfg.audio_backend == "dac":
|
if cfg.audio_backend == "dac":
|
||||||
model = _load_dac_model( device )
|
|
||||||
signal = AudioSignal(wav, sample_rate=sr)
|
signal = AudioSignal(wav, sample_rate=sr)
|
||||||
|
|
||||||
artifact = model.compress(signal, win_duration=window_duration, verbose=False)
|
artifact = model.compress(signal, win_duration=window_duration, verbose=False)
|
||||||
@ -307,23 +328,22 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||||||
|
|
||||||
# NeMo uses a different pathway
|
# NeMo uses a different pathway
|
||||||
if cfg.audio_backend == "nemo":
|
if cfg.audio_backend == "nemo":
|
||||||
model = _load_nemo_model( device )
|
|
||||||
|
|
||||||
wav = wav.to(device)[:, 0, :]
|
wav = wav.to(device)[:, 0, :]
|
||||||
l = torch.tensor([w.shape[0] for w in wav]).to(device)
|
l = torch.tensor([w.shape[0] for w in wav]).to(device)
|
||||||
codes, lens = model.encode(audio=wav, audio_len=l)
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
|
codes, lens = model.encode(audio=wav, audio_len=l)
|
||||||
# to-do: unpad
|
# to-do: unpad
|
||||||
return codes
|
return codes
|
||||||
|
|
||||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||||
if cfg.audio_backend in ["encodec", "vocos"]:
|
if cfg.audio_backend in ["encodec", "vocos"]:
|
||||||
model = _load_encodec_model(device)
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
codes = model.encode(wav)
|
codes = model.encode(wav)
|
||||||
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
||||||
return codes
|
return codes
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ):
|
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda", dtype=None ):
|
||||||
# expand as list
|
# expand as list
|
||||||
if not isinstance(sr, list):
|
if not isinstance(sr, list):
|
||||||
sr = [sr] * len(wavs)
|
sr = [sr] * len(wavs)
|
||||||
@ -351,11 +371,16 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
|||||||
#
|
#
|
||||||
wav = wav.to(device)
|
wav = wav.to(device)
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
wav = wav.to(dtype)
|
||||||
|
|
||||||
|
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
|
||||||
|
|
||||||
# NeMo uses a different pathway
|
# NeMo uses a different pathway
|
||||||
if cfg.audio_backend == "nemo":
|
if cfg.audio_backend == "nemo":
|
||||||
model = _load_nemo_model( device )
|
|
||||||
wav = wav.to(device)[:, 0, :]
|
wav = wav.to(device)[:, 0, :]
|
||||||
codes, code_lens = model.encode(audio=wav, audio_len=lens)
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
|
codes, code_lens = model.encode(audio=wav, audio_len=lens)
|
||||||
return [ code[:, :l] for code, l in zip( codes, code_lens ) ]
|
return [ code[:, :l] for code, l in zip( codes, code_lens ) ]
|
||||||
|
|
||||||
# can't be assed to implement
|
# can't be assed to implement
|
||||||
@ -364,8 +389,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
|||||||
|
|
||||||
# naively encode
|
# naively encode
|
||||||
if cfg.audio_backend in ["encodec", "vocos"]:
|
if cfg.audio_backend in ["encodec", "vocos"]:
|
||||||
model = _load_encodec_model(device)
|
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
codes = model.encode(wav)
|
codes = model.encode(wav)
|
||||||
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
||||||
|
|
||||||
return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]
|
return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user