re-added amp encoding/decoding for audio, possible bad idea to ignore using amp instead if requested
This commit is contained in:
parent
7592befc53
commit
299cc88821
|
@ -38,9 +38,9 @@ def process_items( items, stride=0, stride_offset=0 ):
|
|||
items = sorted( items )
|
||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||
|
||||
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda" ):
|
||||
def process_job( outpath, waveform, sample_rate, text=None, language="en", device="cuda", dtype=None ):
|
||||
# encodec requires this to be on CPU for resampling
|
||||
qnt = quantize(waveform, sr=sample_rate, device=device)
|
||||
qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
|
||||
|
||||
if cfg.audio_backend == "dac":
|
||||
state_dict = {
|
||||
|
@ -75,10 +75,13 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
|
|||
|
||||
np.save(open(outpath, "wb"), state_dict)
|
||||
|
||||
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
||||
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
|
||||
if not jobs:
|
||||
return
|
||||
|
||||
# sort to avoid egregious padding
|
||||
jobs = sorted(jobs, key=lambda x: x[1].shape[-1], reverse=True)
|
||||
|
||||
buffer = []
|
||||
batches = []
|
||||
|
||||
|
@ -88,7 +91,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||
batches.append(buffer)
|
||||
buffer = []
|
||||
|
||||
if len(buffer) >= batch_size:
|
||||
if buffer:
|
||||
batches.append(buffer)
|
||||
buffer = []
|
||||
|
||||
|
@ -101,7 +104,7 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||
srs.append(sample_rate)
|
||||
|
||||
try:
|
||||
codes = quantize_batch(wavs, sr=srs, device=device)
|
||||
codes = quantize_batch(wavs, sr=srs, device=device, dtype=dtype)
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||
if raise_exceptions:
|
||||
|
@ -142,18 +145,18 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
|
|||
|
||||
np.save(open(outpath, "wb"), state_dict)
|
||||
|
||||
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
||||
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1, dtype=None ):
|
||||
if not jobs:
|
||||
return
|
||||
|
||||
# batch things
|
||||
if batch_size > 1:
|
||||
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype )
|
||||
|
||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||
outpath, waveform, sample_rate, text, language = job
|
||||
try:
|
||||
process_job( outpath, waveform, sample_rate, text, language, device )
|
||||
process_job( outpath, waveform, sample_rate, text, language, device, dtype=dtype )
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||
if raise_exceptions:
|
||||
|
@ -186,6 +189,8 @@ def process(
|
|||
cfg.inference.weight_dtype = dtype # "bfloat16"
|
||||
cfg.inference.amp = amp # False
|
||||
|
||||
dtype = None
|
||||
|
||||
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
||||
|
||||
# to-do: make this also prepared from args
|
||||
|
@ -334,12 +339,12 @@ def process(
|
|||
|
||||
# processes audio files one at a time
|
||||
if low_memory:
|
||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||
jobs = []
|
||||
|
||||
# processes all audio files for a given speaker
|
||||
if not low_memory:
|
||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||
jobs = []
|
||||
|
||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||
|
|
|
@ -45,7 +45,7 @@ except Exception as e:
|
|||
_logger.warning(str(e))
|
||||
|
||||
@cache
|
||||
def _load_encodec_model(device="cuda", levels=0):
|
||||
def _load_encodec_model(device="cuda", dtype=None, levels=0):
|
||||
assert cfg.sample_rate == 24_000
|
||||
|
||||
if not levels:
|
||||
|
@ -67,6 +67,9 @@ def _load_encodec_model(device="cuda", levels=0):
|
|||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# extra metadata
|
||||
model.bandwidth_id = bandwidth_id
|
||||
model.normalize = cfg.inference.normalize
|
||||
|
@ -75,7 +78,7 @@ def _load_encodec_model(device="cuda", levels=0):
|
|||
return model
|
||||
|
||||
@cache
|
||||
def _load_vocos_model(device="cuda", levels=0):
|
||||
def _load_vocos_model(device="cuda", dtype=None, levels=0):
|
||||
assert cfg.sample_rate == 24_000
|
||||
|
||||
if not levels:
|
||||
|
@ -85,6 +88,9 @@ def _load_vocos_model(device="cuda", levels=0):
|
|||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# too lazy to un-if ladder this shit
|
||||
bandwidth_id = 2
|
||||
if levels == 2:
|
||||
|
@ -101,7 +107,7 @@ def _load_vocos_model(device="cuda", levels=0):
|
|||
return model
|
||||
|
||||
@cache
|
||||
def _load_dac_model(device="cuda"):
|
||||
def _load_dac_model(device="cuda", dtype=None):
|
||||
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
|
||||
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
|
||||
if cfg.sample_rate == 44_100:
|
||||
|
@ -115,17 +121,25 @@ def _load_dac_model(device="cuda"):
|
|||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
model.backend = "dac"
|
||||
model.model_type = kwargs["model_type"]
|
||||
|
||||
return model
|
||||
|
||||
@cache
|
||||
def _load_nemo_model(device="cuda", model_name=None):
|
||||
def _load_nemo_model(device="cuda", dtype=None, model_name=None):
|
||||
if not model_name:
|
||||
model_name = "nvidia/audio-codec-44khz"
|
||||
|
||||
model = AudioCodecModel.from_pretrained(model_name).to(device).eval()
|
||||
model = AudioCodecModel.from_pretrained(model_name)
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
model.backend = "nemo"
|
||||
|
||||
|
@ -133,20 +147,23 @@ def _load_nemo_model(device="cuda", model_name=None):
|
|||
|
||||
|
||||
@cache
|
||||
def _load_model(device="cuda", backend=None):
|
||||
def _load_model(device="cuda", backend=None, dtype=None):
|
||||
if not backend:
|
||||
backend = cfg.audio_backend
|
||||
|
||||
if backend == "nemo":
|
||||
return _load_nemo_model(device)
|
||||
if backend == "audiodec":
|
||||
return _load_audiodec_model(device)
|
||||
if backend == "dac":
|
||||
return _load_dac_model(device)
|
||||
if backend == "vocos":
|
||||
return _load_vocos_model(device)
|
||||
if cfg.inference.amp:
|
||||
dtype = None
|
||||
|
||||
return _load_encodec_model(device)
|
||||
if backend == "nemo":
|
||||
return _load_nemo_model(device, dtype=dtype)
|
||||
if backend == "audiodec":
|
||||
return _load_audiodec_model(device, dtype=dtype)
|
||||
if backend == "dac":
|
||||
return _load_dac_model(device, dtype=dtype)
|
||||
if backend == "vocos":
|
||||
return _load_vocos_model(device, dtype=dtype)
|
||||
|
||||
return _load_encodec_model(device, dtype=dtype)
|
||||
|
||||
def unload_model():
|
||||
_load_model.cache_clear()
|
||||
|
@ -154,7 +171,7 @@ def unload_model():
|
|||
|
||||
# to-do: clean up this mess
|
||||
@torch.inference_mode()
|
||||
def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
||||
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
||||
# upcast so it won't whine
|
||||
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||
codes = codes.to(torch.int32)
|
||||
|
@ -174,7 +191,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
|||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||
|
||||
# load the model
|
||||
model = _load_model(device)
|
||||
model = _load_model(device, dtype=dtype)
|
||||
# move to device
|
||||
codes = codes.to( device=device )
|
||||
|
||||
|
@ -231,7 +248,7 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
|
|||
return wav, cfg.sample_rate
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_batch(codes: list[Tensor], device="cuda"):
|
||||
def decode_batch(codes: list[Tensor], device="cuda", dtype=None):
|
||||
# transpose if needed
|
||||
for i, code in enumerate(codes):
|
||||
if code.shape[0] < code.shape[1]:
|
||||
|
@ -254,7 +271,7 @@ def decode_batch(codes: list[Tensor], device="cuda"):
|
|||
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
|
||||
|
||||
# load the model
|
||||
model = _load_model(device)
|
||||
model = _load_model(device, dtype=dtype)
|
||||
# move to device
|
||||
codes = codes.to( device=device )
|
||||
|
||||
|
@ -280,7 +297,7 @@ def _replace_file_extension(path, suffix):
|
|||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
|
||||
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", dtype=None, return_metadata=True, window_duration=None):
|
||||
# expand if 1D
|
||||
if wav.dim() < 2:
|
||||
wav = wav.unsqueeze(0)
|
||||
|
@ -288,12 +305,16 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||
if wav.dim() < 3:
|
||||
wav = wav.unsqueeze(0)
|
||||
|
||||
if dtype is not None:
|
||||
wav = wav.to(dtype)
|
||||
|
||||
# cringe assert
|
||||
assert wav.shape[0] == 1, f'Batch encoding is unsupported with vanilla encode()'
|
||||
|
||||
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
|
||||
|
||||
# DAC uses a different pathway
|
||||
if cfg.audio_backend == "dac":
|
||||
model = _load_dac_model( device )
|
||||
signal = AudioSignal(wav, sample_rate=sr)
|
||||
|
||||
artifact = model.compress(signal, win_duration=window_duration, verbose=False)
|
||||
|
@ -307,23 +328,22 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||
|
||||
# NeMo uses a different pathway
|
||||
if cfg.audio_backend == "nemo":
|
||||
model = _load_nemo_model( device )
|
||||
|
||||
wav = wav.to(device)[:, 0, :]
|
||||
l = torch.tensor([w.shape[0] for w in wav]).to(device)
|
||||
codes, lens = model.encode(audio=wav, audio_len=l)
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
codes, lens = model.encode(audio=wav, audio_len=l)
|
||||
# to-do: unpad
|
||||
return codes
|
||||
|
||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||
if cfg.audio_backend in ["encodec", "vocos"]:
|
||||
model = _load_encodec_model(device)
|
||||
codes = model.encode(wav)
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
codes = model.encode(wav)
|
||||
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
||||
return codes
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda" ):
|
||||
def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, device="cuda", dtype=None ):
|
||||
# expand as list
|
||||
if not isinstance(sr, list):
|
||||
sr = [sr] * len(wavs)
|
||||
|
@ -351,11 +371,16 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
|||
#
|
||||
wav = wav.to(device)
|
||||
|
||||
if dtype is not None:
|
||||
wav = wav.to(dtype)
|
||||
|
||||
model = _load_encodec_model( device, dtype=dtype ) if cfg.audio_backend == "vocos" else _load_model( device, dtype=dtype )
|
||||
|
||||
# NeMo uses a different pathway
|
||||
if cfg.audio_backend == "nemo":
|
||||
model = _load_nemo_model( device )
|
||||
wav = wav.to(device)[:, 0, :]
|
||||
codes, code_lens = model.encode(audio=wav, audio_len=lens)
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
codes, code_lens = model.encode(audio=wav, audio_len=lens)
|
||||
return [ code[:, :l] for code, l in zip( codes, code_lens ) ]
|
||||
|
||||
# can't be assed to implement
|
||||
|
@ -364,8 +389,8 @@ def encode_batch( wavs: list[Tensor], sr: list[int] | int = cfg.sample_rate, dev
|
|||
|
||||
# naively encode
|
||||
if cfg.audio_backend in ["encodec", "vocos"]:
|
||||
model = _load_encodec_model(device)
|
||||
codes = model.encode(wav)
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
codes = model.encode(wav)
|
||||
codes = torch.cat([code[0] for code in codes], dim=-1) # (b q t)
|
||||
|
||||
return [ code[:, :l * cfg.dataset.frames_per_second // cfg.sample_rate] for code, l in zip(codes, lens) ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user