slight fixes

This commit is contained in:
mrq 2023-09-03 12:34:55 +00:00
parent 7110b878b7
commit 7fc8f4c45a

View File

@ -172,8 +172,6 @@ if BARK_ENABLED:
try: try:
from hubert.hubert_manager import HuBERTManager from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager() hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed() hubert_manager.make_sure_hubert_installed()
@ -244,6 +242,9 @@ if BARK_ENABLED:
# generate semantic tokens # generate semantic tokens
if self.hubert_enabled: if self.hubert_enabled:
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
wav = wav.to(self.device) wav = wav.to(self.device)
# Extract discrete codes from EnCodec # Extract discrete codes from EnCodec
@ -299,7 +300,7 @@ if BARK_ENABLED:
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False) semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False ) audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
if VOCOS_ENABLED: if self.vocos_enabled:
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device) audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
features = self.vocos.codes_to_features(audio_tokens_torch) features = self.vocos.codes_to_features(audio_tokens_torch)
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device)) wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
@ -510,7 +511,7 @@ def generate_bark(**kwargs):
settings['datetime'] = datetime.now().isoformat() settings['datetime'] = datetime.now().isoformat()
# save here in case some error happens mid-batch # save here in case some error happens mid-batch
if VOCOS_ENABLED: if tts.vocos_enabled:
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
else: else:
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav) write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints) training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
if training_state.killed: if training_state is None or training_state.killed:
return return
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress ) result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
@ -3359,12 +3360,11 @@ def setup_args(cli=False):
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069") parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1") parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false") parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
if cli: if cli:
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
else: else: