From 299cc88821a099efed25535427dce87d75abc7ea Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 5 Feb 2025 21:55:06 -0600 Subject: [PATCH] re-added amp encoding/decoding for audio, possible bad idea to ignore using amp instead if requested --- vall_e/emb/process.py | 25 ++++++++----- vall_e/emb/qnt.py | 87 ++++++++++++++++++++++++++++--------------- 2 files changed, 71 insertions(+), 41 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 035b697..49fb22f 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -38,9 +38,9 @@ def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] -def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ): +def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda", dtype=None ): # encodec requires this to be on CPU for resampling - qnt = quantize(waveform, sr=sample_rate, device=device) + qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype) if cfg.audio_backend == "dac": state_dict = { @@ -75,10 +75,13 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic np.save(open(outpath, "wb"), state_dict) -def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ): +def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ): if not jobs: return + # sort to avoid egregious padding + jobs = sorted(jobs, key=lambda x: x[1].shape[-1], reverse=True) + buffer = [] batches = [] @@ -88,7 +91,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru batches.append(buffer) buffer = [] - if len(buffer) >= batch_size: + if buffer: batches.append(buffer) buffer = [] @@ -101,7 +104,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru srs.append(sample_rate) try: - codes = quantize_batch(wavs, sr=srs, device=device) + codes = quantize_batch(wavs, sr=srs, device=device, dtype=dtype) except Exception as e: _logger.error(f"Failed to quantize: {outpath}: {str(e)}") if raise_exceptions: @@ -142,18 +145,18 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru np.save(open(outpath, "wb"), state_dict) -def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ): +def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ): if not jobs: return # batch things if batch_size > 1: - return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size ) + return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype ) for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"): outpath, waveform, sample_rate, text, language = job try: - process_job( outpath, waveform, sample_rate, text, language, device ) + process_job( outpath, waveform, sample_rate, text, language, device, dtype=dtype ) except Exception as e: _logger.error(f"Failed to quantize: {outpath}: {str(e)}") if raise_exceptions: @@ -186,6 +189,8 @@ def process( cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False + dtype = None + output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" # to-do: make this also prepared from args @@ -334,12 +339,12 @@ def process( # processes audio files one at a time if low_memory: - process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size ) + process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) jobs = [] # processes all audio files for a given speaker if not low_memory: - process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size ) + process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 09c6742..b093f1f 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -45,7 +45,7 @@ except Exception as e: _logger.warning(str(e)) @cache -def _load_encodec_model(device="cuda", levels=0): +def _load_encodec_model(device="cuda", dtype=None, levels=0): assert cfg.sample_rate == 24_000 if not levels: @@ -67,6 +67,9 @@ def _load_encodec_model(device="cuda", levels=0): model = model.to(device) model = model.eval() + if dtype is not None: + model = model.to(dtype) + # extra metadata model.bandwidth_id = bandwidth_id model.normalize = cfg.inference.normalize @@ -75,7 +78,7 @@ def _load_encodec_model(device="cuda", levels=0): return model @cache -def _load_vocos_model(device="cuda", levels=0): +def _load_vocos_model(device="cuda", dtype=None, levels=0): assert cfg.sample_rate == 24_000 if not levels: @@ -85,6 +88,9 @@ def _load_vocos_model(device="cuda", levels=0): model = model.to(device) model = model.eval() + if dtype is not None: + model = model.to(dtype) + # too lazy to un-if ladder this shit bandwidth_id = 2 if levels == 2: @@ -101,7 +107,7 @@ def _load_vocos_model(device="cuda", levels=0): return model @cache -def _load_dac_model(device="cuda"): +def _load_dac_model(device="cuda", dtype=None): kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest") # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' if cfg.sample_rate == 44_100: @@ -115,17 +121,25 @@ def _load_dac_model(device="cuda"): model = model.to(device) model = model.eval() + if dtype is not None: + model = model.to(dtype) + model.backend = "dac" model.model_type = kwargs["model_type"] return model @cache -def _load_nemo_model(device="cuda", model_name=None): +def _load_nemo_model(device="cuda", dtype=None, model_name=None): if not model_name: model_name = "nvidia/audio-codec-44khz" - model = AudioCodecModel.from_pretrained(model_name).to(device).eval() + model = AudioCodecModel.from_pretrained(model_name) + model = model.to(device) + model = model.eval() + + if dtype is not None: + model = model.to(dtype) model.backend = "nemo" @@ -133,20 +147,23 @@ def _load_nemo_model(device="cuda", model_name=None): @cache -def _load_model(device="cuda", backend=None): +def _load_model(device="cuda", backend=None, dtype=None): if not backend: backend = cfg.audio_backend - if backend == "nemo": - return _load_nemo_model(device) - if backend == "audiodec": - return _load_audiodec_model(device) - if backend == "dac": - return _load_dac_model(device) - if backend == "vocos": - return _load_vocos_model(device) + if cfg.inference.amp: + dtype = None - return _load_encodec_model(device) + if backend == "nemo": + return _load_nemo_model(device, dtype=dtype) + if backend == "audiodec": + return _load_audiodec_model(device, dtype=dtype) + if backend == "dac": + return _load_dac_model(device, dtype=dtype) + if backend == "vocos": + return _load_vocos_model(device, dtype=dtype) + + return _load_encodec_model(device, dtype=dtype) def unload_model(): _load_model.cache_clear() @@ -154,7 +171,7 @@ def unload_model(): # to-do: clean up this mess @torch.inference_mode() -def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): +def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None): # upcast so it won't whine if codes.dtype in [torch.int8, torch.int16, torch.uint8]: codes = codes.to(torch.int32) @@ -174,7 +191,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}' # load the model - model = _load_model(device) + model = _load_model(device, dtype=dtype) # move to device codes = codes.to( device=device ) @@ -231,7 +248,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None): return wav, cfg.sample_rate @torch.inference_mode() -def decode_batch(codes: list[Tensor], device="cuda"): +def decode_batch(codes: list[Tensor], device="cuda", dtype=None): # transpose if needed for i, code in enumerate(codes): if code.shape[0] < code.shape[1]: @@ -254,7 +271,7 @@ def decode_batch(codes: list[Tensor], device="cuda"): assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}' # load the model - model = _load_model(device) + model = _load_model(device, dtype=dtype) # move to device codes = codes.to( device=device ) @@ -280,7 +297,7 @@ def _replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix) @torch.inference_mode() -def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None): +def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, return_metadata=True, window_duration=None): # expand if 1D if wav.dim() < 2: wav = wav.unsqueeze(0) @@ -288,12 +305,16 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat if wav.dim() < 3: wav = wav.unsqueeze(0) + if dtype is not None: + wav = wav.to(dtype) + # cringe assert assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()' + + model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype ) # DAC uses a different pathway if cfg.audio_backend == "dac": - model = _load_dac_model( device ) signal = AudioSignal(wav, sample_rate=sr) artifact = model.compress(signal, win_duration=window_duration, verbose=False) @@ -307,23 +328,22 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat # NeMo uses a different pathway if cfg.audio_backend == "nemo": - model = _load_nemo_model( device ) - wav = wav.to(device)[:, 0, :] l = torch.tensor([w.shape[0] for w in wav]).to(device) - codes, lens = model.encode(audio=wav, audio_len=l) + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): + codes, lens = model.encode(audio=wav, audio_len=l) # to-do: unpad return codes # vocos does not encode wavs to encodecs, so just use normal encodec if cfg.audio_backend in ["encodec", "vocos"]: - model = _load_encodec_model(device) - codes = model.encode(wav) + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): + codes = model.encode(wav) codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t) return codes @torch.inference_mode() -def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ): +def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda", dtype=None ): # expand as list if not isinstance(sr, list): sr = [sr] * len(wavs) @@ -351,11 +371,16 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev # wav = wav.to(device) + if dtype is not None: + wav = wav.to(dtype) + + model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype ) + # NeMo uses a different pathway if cfg.audio_backend == "nemo": - model = _load_nemo_model( device ) wav = wav.to(device)[:, 0, :] - codes, code_lens = model.encode(audio=wav, audio_len=lens) + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): + codes, code_lens = model.encode(audio=wav, audio_len=lens) return [ code[:, :l] for code, l in zip( codes, code_lens ) ] # can't be assed to implement @@ -364,8 +389,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev # naively encode if cfg.audio_backend in ["encodec", "vocos"]: - model = _load_encodec_model(device) - codes = model.encode(wav) + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): + codes = model.encode(wav) codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t) return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]