diff --git a/src/utils.py b/src/utils.py index 6457b77..435cd4d 100755 --- a/src/utils.py +++ b/src/utils.py @@ -172,8 +172,6 @@ if BARK_ENABLED: try: from hubert.hubert_manager import HuBERTManager - from hubert.pre_kmeans_hubert import CustomHubert - from hubert.customtokenizer import CustomTokenizer hubert_manager = HuBERTManager() hubert_manager.make_sure_hubert_installed() @@ -244,6 +242,9 @@ if BARK_ENABLED: # generate semantic tokens if self.hubert_enabled: + from hubert.pre_kmeans_hubert import CustomHubert + from hubert.customtokenizer import CustomTokenizer + wav = wav.to(self.device) # Extract discrete codes from EnCodec @@ -299,7 +300,7 @@ if BARK_ENABLED: semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False) audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False ) - if VOCOS_ENABLED: + if self.vocos_enabled: audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device) features = self.vocos.codes_to_features(audio_tokens_torch) wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device)) @@ -510,7 +511,7 @@ def generate_bark(**kwargs): settings['datetime'] = datetime.now().isoformat() # save here in case some error happens mid-batch - if VOCOS_ENABLED: + if tts.vocos_enabled: torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) else: write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav) @@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints) for line in iter(training_state.process.stdout.readline, ""): - if training_state.killed: + if training_state is None or training_state.killed: return result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress ) @@ -3359,12 +3360,11 @@ def setup_args(cli=False): parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") - + parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069") parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1") parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false") - - parser.add_argument("--os", default="unix", help="Specifies which OS, easily") + if cli: args, unknown = parser.parse_known_args() else: