slight fixes
This commit is contained in:
parent
7110b878b7
commit
7fc8f4c45a
16
src/utils.py
16
src/utils.py
|
@ -172,8 +172,6 @@ if BARK_ENABLED:
|
|||
|
||||
try:
|
||||
from hubert.hubert_manager import HuBERTManager
|
||||
from hubert.pre_kmeans_hubert import CustomHubert
|
||||
from hubert.customtokenizer import CustomTokenizer
|
||||
|
||||
hubert_manager = HuBERTManager()
|
||||
hubert_manager.make_sure_hubert_installed()
|
||||
|
@ -244,6 +242,9 @@ if BARK_ENABLED:
|
|||
# generate semantic tokens
|
||||
|
||||
if self.hubert_enabled:
|
||||
from hubert.pre_kmeans_hubert import CustomHubert
|
||||
from hubert.customtokenizer import CustomTokenizer
|
||||
|
||||
wav = wav.to(self.device)
|
||||
|
||||
# Extract discrete codes from EnCodec
|
||||
|
@ -299,7 +300,7 @@ if BARK_ENABLED:
|
|||
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
|
||||
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
|
||||
|
||||
if VOCOS_ENABLED:
|
||||
if self.vocos_enabled:
|
||||
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
|
||||
features = self.vocos.codes_to_features(audio_tokens_torch)
|
||||
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
|
||||
|
@ -510,7 +511,7 @@ def generate_bark(**kwargs):
|
|||
settings['datetime'] = datetime.now().isoformat()
|
||||
|
||||
# save here in case some error happens mid-batch
|
||||
if VOCOS_ENABLED:
|
||||
if tts.vocos_enabled:
|
||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||
else:
|
||||
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
||||
|
@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
|||
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
if training_state.killed:
|
||||
if training_state is None or training_state.killed:
|
||||
return
|
||||
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
|
||||
|
@ -3359,12 +3360,11 @@ def setup_args(cli=False):
|
|||
|
||||
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||
|
||||
|
||||
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
|
||||
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
|
||||
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
|
||||
|
||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||
|
||||
if cli:
|
||||
args, unknown = parser.parse_known_args()
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user