Slight fix; getting close to being able to train directly from the web UI

This commit is contained in:
mrq 2023-02-17 13:57:03 +00:00
parent 8482131e10
commit f87764e7d0
3 changed files with 16 additions and 13 deletions

View File

@ -447,9 +447,9 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
"save_rate": save_rate if save_rate else 50,
"name": name if name else "finetune",
"dataset_name": dataset_name if dataset_name else "finetune",
"dataset_path": dataset_path if dataset_path else "./experiments/finetune/train.txt",
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./experiments/finetune/val.txt",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
}
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
@ -462,7 +462,7 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
f.write(yaml)
whisper_model = None
def prepare_dataset( files, outdir ):
def prepare_dataset( files, outdir, language=None ):
global whisper_model
if whisper_model is None:
whisper_model = whisper.load_model(args.whisper_model)
@ -476,7 +476,7 @@ def prepare_dataset( files, outdir ):
for file in files:
print(f"Transcribing file: {file}")
result = whisper_model.transcribe(file)
result = whisper_model.transcribe(file, language=language)
results[os.path.basename(file)] = result
print(f"Transcribed file: {file}, {len(result['segments'])} found.")

View File

@ -375,14 +375,15 @@ def setup_gradio():
with gr.Column():
dataset_settings = [
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
gr.Textbox(label="Language", placeholder="English")
]
dataset_voices = dataset_settings[0]
with gr.Column():
prepare_dataset_button = gr.Button(value="Prepare")
def prepare_dataset_proxy( voice ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/" )
def prepare_dataset_proxy( voice, language ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
prepare_dataset_button.click(
prepare_dataset_proxy,
@ -403,9 +404,9 @@ def setup_gradio():
training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"),
gr.Textbox(label="Dataset Path", placeholder="./experiments/finetune/train.txt"),
gr.Textbox(label="Dataset Path", placeholder="./training/finetune/train.txt"),
gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./experiments/finetune/val.txt"),
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
]
save_yaml_button.click(save_training_settings,

View File

@ -24,6 +24,7 @@ datasets:
num_conditioning_candidates: 2
conditioning_length: 44000
use_bpe_tokenizer: True
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
load_aligned_codes: False
val:
name: ${validation_name}
@ -40,6 +41,7 @@ datasets:
num_conditioning_candidates: 2
conditioning_length: 44000
use_bpe_tokenizer: True
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
load_aligned_codes: False
steps:
@ -59,20 +61,20 @@ steps:
injectors: # TODO: replace this entire sequence with the GptVoiceLatentInjector
paired_to_mel:
type: torch_mel_spectrogram
mel_norm_file: ./experiments/clips_mel_norms.pth
mel_norm_file: ./models/tortoise/clips_mel_norms.pth
in: wav
out: paired_mel
paired_cond_to_mel:
type: for_each
subtype: torch_mel_spectrogram
mel_norm_file: ./experiments/clips_mel_norms.pth
mel_norm_file: ./models/tortoise/clips_mel_norms.pth
in: conditioning
out: paired_conditioning_mel
to_codes:
type: discrete_token
in: paired_mel
out: paired_mel_codes
dvae_config: "./experiments/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
dvae_config: "./models/tortoise/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
paired_fwd_text:
type: generator
generator: gpt
@ -112,9 +114,9 @@ networks:
#only_alignment_head: False # uv3/4
path:
pretrain_model_gpt: './experiments/autoregressive.pth' # CHANGEME: copy this from tortoise cache
pretrain_model_gpt: './models/tortoise/autoregressive.pth' # CHANGEME: copy this from tortoise cache
strict_load: true
#resume_state: ./experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
#resume_state: ./models/tortoise/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH