from ..config import cfg

import argparse
import random
import torch
import torchaudio

from functools import cache
from pathlib import Path
from typing import Union

from einops import rearrange
from torch import Tensor
from tqdm import tqdm, trange

try:
	from encodec import EncodecModel
	from encodec.utils import convert_audio
except Exception as e:
	cfg.inference.use_encodec = False

try:
	from vocos import Vocos
except Exception as e:
	cfg.inference.use_vocos = False
try:
	from dac import DACFile
	from audiotools import AudioSignal
	from dac.utils import load_model as __load_dac_model

	"""
	Patch decode to skip things related to the metadata (namely the waveform trimming)
	So far it seems the raw waveform can just be returned without any post-processing
	A smart implementation would just reuse the values from the input prompt
	"""
	from dac.model.base import CodecMixin

	@torch.no_grad()
	def CodecMixin_decompress(
		self,
		obj: Union[str, Path, DACFile],
		verbose: bool = False,
	) -> AudioSignal:
		self.eval()
		if isinstance(obj, (str, Path)):
			obj = DACFile.load(obj)

		original_padding = self.padding
		self.padding = obj.padding

		range_fn = range if not verbose else trange

		codes = obj.codes
		original_device = codes.device
		chunk_length = obj.chunk_length
		recons = []

		for i in range_fn(0, codes.shape[-1], chunk_length):
			c = codes[..., i : i + chunk_length].to(self.device)
			z = self.quantizer.from_codes(c)[0]
			r = self.decode(z)
			recons.append(r.to(original_device))

		recons = torch.cat(recons, dim=-1)
		recons = AudioSignal(recons, self.sample_rate)

		# to-do: the rest of the original implementation's post-processing
		if not hasattr(obj, "dummy") or not obj.dummy:
			resample_fn = recons.resample
			loudness_fn = recons.loudness

			# If audio is > 10 minutes long, use the ffmpeg versions
			if recons.signal_duration >= 10 * 60 * 60:
				resample_fn = recons.ffmpeg_resample
				loudness_fn = recons.ffmpeg_loudness

			recons.normalize(obj.input_db)
			resample_fn(obj.sample_rate)
			recons = recons[..., : obj.original_length]
			loudness_fn()
			recons.audio_data = recons.audio_data.reshape(
				-1, obj.channels, obj.original_length
			)

		self.padding = original_padding
		return recons

	CodecMixin.decompress = CodecMixin_decompress
except Exception as e:
	cfg.inference.use_dac = False
	print(str(e))
@cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
	assert cfg.sample_rate == 24_000

	# too lazy to un-if ladder this shit
	bandwidth_id = 6.0
	if levels == 2:
		bandwidth_id = 1.5
	elif levels == 4:
		bandwidth_id = 3.0
	elif levels == 8:
		bandwidth_id = 6.0

	# Instantiate a pretrained EnCodec model
	model = EncodecModel.encodec_model_24khz()
	model.set_target_bandwidth(bandwidth_id)
	model = model.to(device)
	model = model.eval()

	# extra metadata
	model.bandwidth_id = bandwidth_id
	model.sample_rate = cfg.sample_rate
	model.normalize = cfg.inference.normalize
	model.backend = "encodec"

	return model
@cache
def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
	assert cfg.sample_rate == 24_000
	model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
	model = model.to(device)
	model = model.eval()

	# too lazy to un-if ladder this shit
	bandwidth_id = 2
	if levels == 2:
		bandwidth_id = 0
	elif levels == 4:
		bandwidth_id = 1
	elif levels == 8:
		bandwidth_id = 2

	# extra metadata
	model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
	model.sample_rate = cfg.sample_rate
	model.backend = "vocos"

	return model
@cache
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
	kwargs = dict(model_type="24khz", model_bitrate="8kbps", tag="latest")
	if not cfg.variable_sample_rate:
		# yes there's a better way, something like f'{cfg.sample_rate//1000}khz'
		if cfg.sample_rate == 44_000:
			kwargs["model_type"] = "44khz"
		elif cfg.sample_rate == 24_000:
			kwargs["model_type"] = "24khz"
		elif cfg.sample_rate == 16_000:
			kwargs["model_type"] = "16khz"
		else:
			raise Exception(f'unsupported sample rate: {cfg.sample_rate}')

	model = __load_dac_model(**kwargs)
	model = model.to(device)
	model = model.eval()

	# extra metadata
	# since DAC moreso models against waveforms, we can actually use a smaller sample rate
	# updating it here will affect the sample rate the waveform is resampled to on encoding
	if cfg.variable_sample_rate:
		model.sample_rate = cfg.sample_rate

	model.backend = "dac"
	model.model_type = kwargs["model_type"]

	return model
@cache
def _load_model(device="cuda", backend=cfg.inference.audio_backend, levels=cfg.model.max_levels):
	if backend == "dac":
		return _load_dac_model(device, levels=levels)
	if backend == "vocos":
		return _load_vocos_model(device, levels=levels)
	return _load_encodec_model(device, levels=levels)

def unload_model():
	_load_model.cache_clear()
	_load_encodec_model.cache_clear() # because vocos can only decode
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=None):
	# upcast so it won't whine
	if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
		codes = codes.to(torch.int32)

	# expand if we're given a raw 1-RVQ stream
	if codes.dim() == 1:
		codes = rearrange(codes, "t -> 1 1 t")
	# expand to a batch size of one if not passed as a batch
	# vocos does not do batch decoding, but encodec does, but we don't end up using this anyways *I guess*
	# to-do, make this logical
	elif codes.dim() == 2:
		codes = rearrange(codes, "t q -> 1 q t")

	assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'

	# load the model
	model = _load_model(device, levels=levels)

	# DAC uses a different pathway
	if model.backend == "dac":
		dummy = False
		if metadata is None:
			metadata = dict(
				chunk_length=codes.shape[-1],
				original_length=0,
				input_db=-12,
				channels=1,
				sample_rate=model.sample_rate,
				padding=True,
				dac_version='1.0.0',
			)
			dummy = True

		# generate object with copied metadata
		artifact = DACFile(
			codes=codes,
			# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
			chunk_length=metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
			original_length=metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
			input_db=metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
			channels=metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
			sample_rate=metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
			padding=metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
			dac_version=metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
		)
		artifact.dummy = dummy

		# to-do: inject the sample rate encoded at, because we can actually decouple
		return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate

	kwargs = {}
	if model.backend == "vocos":
		x = model.codes_to_features(codes[0])
		kwargs['bandwidth_id'] = model.bandwidth_id
	else:
		# encodec expects a list of (codes, scale) frame tuples
		x = [(codes.to(device), None)]

	wav = model.decode(x, **kwargs)

	# encodec will decode as a batch
	if model.backend == "encodec":
		wav = wav[0]

	return wav, model.sample_rate
# huh
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
	return decode(resps, device=device, levels=levels)

def decode_to_file(resps: Tensor, path: Path, device="cuda"):
	wavs, sr = decode(resps, device=device)
	torchaudio.save(str(path), wavs.cpu(), sr)
	return wavs, sr

def _replace_file_extension(path, suffix):
	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
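
# Minimal usage sketch for the decode helpers (the path and shape are illustrative assumptions,
# written for the EnCodec backend where `resps` is a (t, q) tensor of codes):
#   wavs, sr = decode_to_file(resps, Path("out.wav"), device="cuda")
#   # `out.wav` now holds the reconstructed waveform at the codec's sample rate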
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.model.max_levels, return_metadata=True):
	if cfg.inference.audio_backend == "dac":
		model = _load_dac_model(device, levels=levels)
		signal = AudioSignal(wav, sample_rate=sr)

		if not isinstance(levels, int):
			levels = 8 if model.model_type == "24khz" else None

		with torch.autocast("cuda", dtype=torch.bfloat16, enabled=False): # or True for about 2x speed, not enabling by default for systems that do not have bfloat16
			artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)

		# trim to 8 codebooks if 24khz
		# probably redundant with levels, should rewrite this logic eventually
		if model.model_type == "24khz":
			artifact.codes = artifact.codes[:, :8, :]

		return artifact.codes if not return_metadata else artifact

	# vocos does not encode wavs to encodecs, so just use normal encodec
	model = _load_encodec_model(device, levels=levels)
	wav = wav.unsqueeze(0)
	wav = convert_audio(wav, sr, model.sample_rate, model.channels)
	wav = wav.to(device)

	encoded_frames = model.encode(wav)
	qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)

	return qnt
def encode_from_files(paths, device="cuda"):
	tuples = [torchaudio.load(str(path)) for path in paths]

	wavs = []
	main_sr = tuples[0][1]
	for wav, sr in tuples:
		assert sr == main_sr, "Mismatching sample rates"

		if wav.shape[0] == 2:
			wav = wav[:1]

		wavs.append(wav)

	wav = torch.cat(wavs, dim=-1)

	return encode(wav, sr, "cpu")

def encode_from_file(path, device="cuda"):
	if isinstance(path, list):
		return encode_from_files(path, device)
	else:
		path = str(path)
		wav, sr = torchaudio.load(path)

	if wav.shape[0] == 2:
		wav = wav[:1]

	qnt = encode(wav, sr, device)

	return qnt
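
# Minimal usage sketch (the path is a hypothetical example; with the DAC backend and
# `return_metadata=True` the call returns a DACFile artifact rather than a plain tensor):
#   qnt = encode_from_file("speaker/utterance.wav", device="cuda")  # (1, q, t) codes under EnCodec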
"""
Helper Functions
"""
# trims from the start, up to `target`
def trim(qnt, target):
	length = max(qnt.shape[0], qnt.shape[1])

	if target > 0:
		start = 0
		end = start + target
		if end >= length:
			start = length - target
			end = length
	# negative length specified, trim from the end
	else:
		start = length + target
		end = length
		if start < 0:
			start = 0

	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
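
# Worked example (illustrative shapes): for `qnt` of shape (t=100, q=8),
#   trim(qnt, 25)  -> keeps frames [0:25]   (trimmed from the start)
#   trim(qnt, -25) -> keeps frames [75:100] (trimmed from the end)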
# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
def trim_random(qnt, target):
	length = max(qnt.shape[0], qnt.shape[1])
	start = int(length * random.random())
	end = start + target
	if end >= length:
		start = length - target
		end = length

	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
# repeats the audio to fit the target size
def repeat_extend_audio(qnt, target):
	pieces = []
	length = 0
	while length < target:
		pieces.append(qnt)
		length += qnt.shape[0]

	return trim(torch.cat(pieces), target)
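
# Worked example (illustrative): a 40-frame `qnt` with target=100 gets tiled three times
# (40 -> 80 -> 120 frames) and then trimmed back down to the first 100 frames.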
# merges two quantized audios together
# I don't know if this works
def merge_audio(*args, device="cpu", scale=[], levels=cfg.model.max_levels):
	qnts = [*args]
	decoded = [decode(qnt, device=device, levels=levels)[0] for qnt in qnts]

	if len(scale) == len(decoded):
		for i in range(len(scale)):
			decoded[i] = decoded[i] * scale[i]

	combined = sum(decoded) / len(decoded)
	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
def main():
	parser = argparse.ArgumentParser()
	parser.add_argument("folder", type=Path)
	parser.add_argument("--suffix", default=".wav")
	parser.add_argument("--device", default="cuda")
	parser.add_argument("--backend", default="encodec")
	args = parser.parse_args()

	device = args.device
	paths = [*args.folder.rglob(f"*{args.suffix}")]

	for path in tqdm(paths):
		out_path = _replace_file_extension(path, ".qnt.pt")
		if out_path.exists():
			continue
		qnt = encode_from_file(path, device=device)
		torch.save(qnt.cpu(), out_path)

if __name__ == "__main__":
	main()
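
# Example invocation (the module path is an assumption; adjust it to wherever this file lives):
#   python3 -m vall_e.emb.qnt /path/to/audio --suffix .wav --device cuda
# This writes a `.qnt.pt` tensor next to every matching audio file that does not already have one.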