slight fixes
This commit is contained in:
parent
7110b878b7
commit
7fc8f4c45a
16
src/utils.py
16
src/utils.py
|
@ -172,8 +172,6 @@ if BARK_ENABLED:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from hubert.hubert_manager import HuBERTManager
|
from hubert.hubert_manager import HuBERTManager
|
||||||
from hubert.pre_kmeans_hubert import CustomHubert
|
|
||||||
from hubert.customtokenizer import CustomTokenizer
|
|
||||||
|
|
||||||
hubert_manager = HuBERTManager()
|
hubert_manager = HuBERTManager()
|
||||||
hubert_manager.make_sure_hubert_installed()
|
hubert_manager.make_sure_hubert_installed()
|
||||||
|
@ -244,6 +242,9 @@ if BARK_ENABLED:
|
||||||
# generate semantic tokens
|
# generate semantic tokens
|
||||||
|
|
||||||
if self.hubert_enabled:
|
if self.hubert_enabled:
|
||||||
|
from hubert.pre_kmeans_hubert import CustomHubert
|
||||||
|
from hubert.customtokenizer import CustomTokenizer
|
||||||
|
|
||||||
wav = wav.to(self.device)
|
wav = wav.to(self.device)
|
||||||
|
|
||||||
# Extract discrete codes from EnCodec
|
# Extract discrete codes from EnCodec
|
||||||
|
@ -299,7 +300,7 @@ if BARK_ENABLED:
|
||||||
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
|
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
|
||||||
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
|
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
|
||||||
|
|
||||||
if VOCOS_ENABLED:
|
if self.vocos_enabled:
|
||||||
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
|
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
|
||||||
features = self.vocos.codes_to_features(audio_tokens_torch)
|
features = self.vocos.codes_to_features(audio_tokens_torch)
|
||||||
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
|
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
|
||||||
|
@ -510,7 +511,7 @@ def generate_bark(**kwargs):
|
||||||
settings['datetime'] = datetime.now().isoformat()
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
if VOCOS_ENABLED:
|
if tts.vocos_enabled:
|
||||||
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||||
else:
|
else:
|
||||||
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
||||||
|
@ -1992,7 +1993,7 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
||||||
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
|
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
if training_state.killed:
|
if training_state is None or training_state.killed:
|
||||||
return
|
return
|
||||||
|
|
||||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
|
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
|
||||||
|
@ -3359,12 +3360,11 @@ def setup_args(cli=False):
|
||||||
|
|
||||||
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||||
|
|
||||||
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
|
parser.add_argument("--websocket-listen-port", type=int, default=default_arguments['websocket-listen-port'], help="Websocket server listen port, default: 8069")
|
||||||
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
|
parser.add_argument("--websocket-listen-address", default=default_arguments['websocket-listen-address'], help="Websocket server listen address, default: 127.0.0.1")
|
||||||
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
|
parser.add_argument("--websocket-enabled", action='store_true', default=default_arguments['websocket-enabled'], help="Websocket API server enabled, default: false")
|
||||||
|
|
||||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
|
||||||
if cli:
|
if cli:
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user