slight fixes

This commit is contained in:
mrq 2023-09-03 12:34:55 +00:00
parent 7110b878b7
commit 7fc8f4c45a

View File

@ -172,8 +172,6 @@ if BARK_ENABLED:
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
@ -244,6 +242,9 @@ if BARK_ENABLED:
# generate semantic tokens
if self.hubert_enabled:
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
wav = wav.to(self.device)
# Extract discrete codes from EnCodec
@ -299,7 +300,7 @@ if BARK_ENABLED:
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
if VOCOS_ENABLED:
if self.vocos_enabled:
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
features = self.vocos.codes_to_features(audio_tokens_torch)
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
@ -510,7 +511,7 @@ def generate_bark(**kwargs):
settings['datetime'] = datetime.now().isoformat()
# save here in case some error happens mid-batch
if VOCOS_ENABLED:
if tts.vocos_enabled:
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
else:
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
for line in iter(training_state.process.stdout.readline, ""):
if training_state.killed:
if training_state is None or training_state.killed:
return
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
@ -3364,7 +3365,6 @@ def setup_args(cli=False):
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
if cli:
args, unknown = parser.parse_known_args()
else: