re-added amp encoding/decoding for audio; possibly a bad idea to ignore the explicit dtype and use amp instead when amp is requested

mrq 2025-02-05 21:55:06 -06:00
parent 7592befc53
commit 299cc88821
2 changed files with 71 additions and 41 deletions
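In short: the encode/decode path regains an explicit dtype argument, but that dtype is dropped whenever amp is requested and torch.autocast takes over instead. A rough sketch of the two precision paths (sketch_encode is hypothetical, not part of the diff; model and wav stand in for whichever codec backend and waveform are in play, and float16 stands in for cfg.inference.dtype):

import torch

def sketch_encode(model, wav, dtype=None, amp=False):
    if amp:
        # amp requested: weights stay in full precision, autocast downcasts the math
        with torch.autocast("cuda", dtype=torch.float16, enabled=True):
            return model.encode(wav)
    if dtype is not None:
        # explicit dtype requested: cast both the codec weights and the waveform
        model = model.to(dtype)
        wav = wav.to(dtype)
    return model.encode(wav)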

View File

@@ -38,9 +38,9 @@ def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda", dtype=None ):
# encodec requires this to be on CPU for resampling
qnt = quantize(waveform, sr=sample_rate, device=device)
qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
if cfg.audio_backend == "dac":
state_dict = {
@@ -75,10 +75,13 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
np.save(open(outpath, "wb"), state_dict)
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
if not jobs:
return
# sort to avoid egregious padding
jobs = sorted(jobs, key=lambda x: x[1].shape[-1], reverse=True)
buffer = []
batches = []
@@ -88,7 +91,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
batches.append(buffer)
buffer = []
if len(buffer) >= batch_size:
if buffer:
batches.append(buffer)
buffer = []
@@ -101,7 +104,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
srs.append(sample_rate)
try:
codes = quantize_batch(wavs, sr=srs, device=device)
codes = quantize_batch(wavs, sr=srs, device=device, dtype=dtype)
except Exception as e:
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
if raise_exceptions:
@@ -142,18 +145,18 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
np.save(open(outpath, "wb"), state_dict)
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
if not jobs:
return
# batch things
if batch_size > 1:
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype )
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
outpath, waveform, sample_rate, text, language = job
try:
process_job( outpath, waveform, sample_rate, text, language, device )
process_job( outpath, waveform, sample_rate, text, language, device, dtype=dtype )
except Exception as e:
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
if raise_exceptions:
@@ -186,6 +189,8 @@ def process(
cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
dtype = None
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
# to-do: make this also prepared from args
@@ -334,12 +339,12 @@ def process(
# processes audio files one at a time
if low_memory:
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
jobs = []
# processes all audio files for a given speaker
if not low_memory:
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
jobs = []
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))

View File

@@ -45,7 +45,7 @@ except Exception as e:
_logger.warning(str(e))
@cache
def _load_encodec_model(device="cuda", levels=0):
def _load_encodec_model(device="cuda", dtype=None, levels=0):
assert cfg.sample_rate == 24_000
if not levels:
@@ -67,6 +67,9 @@ def _load_encodec_model(device="cuda", levels=0):
model = model.to(device)
model = model.eval()
if dtype is not None:
model = model.to(dtype)
# extra metadata
model.bandwidth_id = bandwidth_id
model.normalize = cfg.inference.normalize
@@ -75,7 +78,7 @@ def _load_encodec_model(device="cuda", levels=0):
return model
@cache
def _load_vocos_model(device="cuda", levels=0):
def _load_vocos_model(device="cuda", dtype=None, levels=0):
assert cfg.sample_rate == 24_000
if not levels:
@@ -85,6 +88,9 @@ def _load_vocos_model(device="cuda", levels=0):
model = model.to(device)
model = model.eval()
if dtype is not None:
model = model.to(dtype)
# too lazy to un-if ladder this shit
bandwidth_id = 2
if levels == 2:
@@ -101,7 +107,7 @@ def _load_vocos_model(device="cuda", levels=0):
return model
@cache
def _load_dac_model(device="cuda"):
def _load_dac_model(device="cuda", dtype=None):
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
if cfg.sample_rate == 44_100:
@@ -115,17 +121,25 @@ def _load_dac_model(device="cuda"):
model = model.to(device)
model = model.eval()
if dtype is not None:
model = model.to(dtype)
model.backend = "dac"
model.model_type = kwargs["model_type"]
return model
@cache
def _load_nemo_model(device="cuda", model_name=None):
def _load_nemo_model(device="cuda", dtype=None, model_name=None):
if not model_name:
model_name = "nvidia/audio-codec-44khz"
model = AudioCodecModel.from_pretrained(model_name).to(device).eval()
model = AudioCodecModel.from_pretrained(model_name)
model = model.to(device)
model = model.eval()
if dtype is not None:
model = model.to(dtype)
model.backend = "nemo"
@@ -133,20 +147,23 @@ def _load_nemo_model(device="cuda", model_name=None):
@cache
def _load_model(device="cuda", backend=None):
def _load_model(device="cuda", backend=None, dtype=None):
if not backend:
backend = cfg.audio_backend
if backend == "nemo":
return _load_nemo_model(device)
if backend == "audiodec":
return _load_audiodec_model(device)
if backend == "dac":
return _load_dac_model(device)
if backend == "vocos":
return _load_vocos_model(device)
if cfg.inference.amp:
dtype = None
return _load_encodec_model(device)
if backend == "nemo":
return _load_nemo_model(device, dtype=dtype)
if backend == "audiodec":
return _load_audiodec_model(device, dtype=dtype)
if backend == "dac":
return _load_dac_model(device, dtype=dtype)
if backend == "vocos":
return _load_vocos_model(device, dtype=dtype)
return _load_encodec_model(device, dtype=dtype)
def unload_model():
_load_model.cache_clear()
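Each loader now follows the same cast-after-load pattern; a self-contained sketch of it, with a dummy nn.Conv1d standing in for the real codec since the backends construct their models differently:

from functools import cache
import torch
from torch import nn

@cache
def _load_codec_model_sketch(device="cpu", dtype=None):
    model = nn.Conv1d(1, 8, 3)   # stand-in for the EnCodec/Vocos/DAC/NeMo codec
    model = model.to(device)
    model = model.eval()
    # cast the weights only when an explicit dtype was requested; with amp the
    # caller passes dtype=None and precision is handled by torch.autocast instead
    if dtype is not None:
        model = model.to(dtype)
    return model

half_model = _load_codec_model_sketch(dtype=torch.float16)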
@@ -154,7 +171,7 @@ def unload_model():
# to-do: clean up this mess
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
# upcast so it won't whine
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
codes = codes.to(torch.int32)
@@ -174,7 +191,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
# load the model
model = _load_model(device)
model = _load_model(device, dtype=dtype)
# move to device
codes = codes.to( device=device )
@@ -231,7 +248,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
return wav, cfg.sample_rate
@torch.inference_mode()
def decode_batch(codes: list[Tensor], device="cuda"):
def decode_batch(codes: list[Tensor], device="cuda", dtype=None):
# transpose if needed
for i, code in enumerate(codes):
if code.shape[0] < code.shape[1]:
@@ -254,7 +271,7 @@ def decode_batch(codes: list[Tensor], device="cuda"):
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
# load the model
model = _load_model(device)
model = _load_model(device, dtype=dtype)
# move to device
codes = codes.to( device=device )
@@ -280,7 +297,7 @@ def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, return_metadata=True, window_duration=None):
# expand if 1D
if wav.dim() < 2:
wav = wav.unsqueeze(0)
@@ -288,12 +305,16 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
if wav.dim() < 3:
wav = wav.unsqueeze(0)
if dtype is not None:
wav = wav.to(dtype)
# cringe assert
assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()'
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
# DAC uses a different pathway
if cfg.audio_backend == "dac":
model = _load_dac_model( device )
signal = AudioSignal(wav, sample_rate=sr)
artifact = model.compress(signal, win_duration=window_duration, verbose=False)
@@ -307,23 +328,22 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
wav = wav.to(device)[:, 0, :]
l = torch.tensor([w.shape[0] for w in wav]).to(device)
codes, lens = model.encode(audio=wav, audio_len=l)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
codes, lens = model.encode(audio=wav, audio_len=l)
# to-do: unpad
return codes
# vocos does not encode wavs to encodecs, so just use normal encodec
if cfg.audio_backend in ["encodec", "vocos"]:
model = _load_encodec_model(device)
codes = model.encode(wav)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
codes = model.encode(wav)
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
return codes
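The encodec/vocos branch leans on EncodecModel.encode() returning a list of (codes, scale) frames that get concatenated along time; a small sketch of that pathway with amp disabled, assuming the encodec package and its pretrained 24 kHz model are available:

import torch
from encodec import EncodecModel

model = EncodecModel.encodec_model_24khz().eval()
model.set_target_bandwidth(6.0)

wav = torch.zeros(1, 1, 24_000)  # (batch, channels, samples): one second of silence
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16, enabled=False):
    frames = model.encode(wav)
codes = torch.cat([frame[0] for frame in frames], dim=-1)  # (b, q, t)
print(codes.shape)  # torch.Size([1, 8, 75]): 8 codebooks at 6 kbps, 75 frames for 1 s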
@torch.inference_mode()
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ):
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda", dtype=None ):
# expand as list
if not isinstance(sr, list):
sr = [sr] * len(wavs)
@@ -351,11 +371,16 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
#
wav = wav.to(device)
if dtype is not None:
wav = wav.to(dtype)
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
wav = wav.to(device)[:, 0, :]
codes, code_lens = model.encode(audio=wav, audio_len=lens)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
codes, code_lens = model.encode(audio=wav, audio_len=lens)
return [ code[:, :l] for code, l in zip( codes, code_lens ) ]
# can't be assed to implement
@@ -364,8 +389,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
# naively encode
if cfg.audio_backend in ["encodec", "vocos"]:
model = _load_encodec_model(device)
codes = model.encode(wav)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
codes = model.encode(wav)
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]
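The final trim maps each item's unpadded sample count to a codec-frame count; a quick worked example, assuming the EnCodec backend's 75 frames per second at a 24 kHz sample rate:

frames_per_second = 75    # assumed cfg.dataset.frames_per_second for 24 kHz EnCodec
sample_rate = 24_000      # cfg.sample_rate
lens = [48_000, 36_000]   # unpadded sample counts per item in the batch
code_lens = [l * frames_per_second // sample_rate for l in lens]
print(code_lens)          # [150, 112] frames kept per item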