@ -172,8 +172,6 @@ if BARK_ENABLED:
try :
from hubert . hubert_manager import HuBERTManager
from hubert . pre_kmeans_hubert import CustomHubert
from hubert . customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager ( )
hubert_manager . make_sure_hubert_installed ( )
@ -244,6 +242,9 @@ if BARK_ENABLED:
# generate semantic tokens
if self . hubert_enabled :
from hubert . pre_kmeans_hubert import CustomHubert
from hubert . customtokenizer import CustomTokenizer
wav = wav . to ( self . device )
# Extract discrete codes from EnCodec
@ -299,7 +300,7 @@ if BARK_ENABLED:
semantic_tokens = text_to_semantic ( text , history_prompt = voice , temp = text_temp , silent = False )
audio_tokens = semantic_to_audio_tokens ( semantic_tokens , history_prompt = voice , temp = waveform_temp , silent = False , output_full = False )
if VOCOS_ENABLED :
if self . vocos_enabled :
audio_tokens_torch = torch . from_numpy ( audio_tokens ) . to ( self . device )
features = self . vocos . codes_to_features ( audio_tokens_torch )
wav = self . vocos . decode ( features , bandwidth_id = torch . tensor ( [ 2 ] , device = self . device ) )
@ -510,7 +511,7 @@ def generate_bark(**kwargs):
settings [ ' datetime ' ] = datetime . now ( ) . isoformat ( )
# save here in case some error happens mid-batch
if VOCOS_ENABLED :
if tts. vocos_enabled :
torchaudio . save ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' , wav . cpu ( ) , sr )
else :
write_wav ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' , sr , wav )
@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
training_state = TrainingState ( config_path = config_path , keep_x_past_checkpoints = keep_x_past_checkpoints )
for line in iter ( training_state . process . stdout . readline , " " ) :
if training_state . killed :
if training_state is None or training_state . killed :
return
result , percent , message = training_state . parse ( line = line , verbose = verbose , keep_x_past_checkpoints = keep_x_past_checkpoints , progress = progress )
@ -3359,12 +3360,11 @@ def setup_args(cli=False):
parser . add_argument ( " --training-default-halfp " , action = ' store_true ' , default = default_arguments [ ' training-default-halfp ' ] , help = " Training default: halfp " )
parser . add_argument ( " --training-default-bnb " , action = ' store_true ' , default = default_arguments [ ' training-default-bnb ' ] , help = " Training default: bnb " )
parser . add_argument ( " --websocket-listen-port " , type = int , default = default_arguments [ ' websocket-listen-port ' ] , help = " Websocket server listen port, default: 8069 " )
parser . add_argument ( " --websocket-listen-address " , default = default_arguments [ ' websocket-listen-address ' ] , help = " Websocket server listen address, default: 127.0.0.1 " )
parser . add_argument ( " --websocket-enabled " , action = ' store_true ' , default = default_arguments [ ' websocket-enabled ' ] , help = " Websocket API server enabled, default: false " )
parser . add_argument ( " --os " , default = " unix " , help = " Specifies which OS, easily " )
if cli :
args , unknown = parser . parse_known_args ( )
else :