1
0

added mel LR weight (as I finally understand when to adjust the text), added text validation on dataset creation

This commit is contained in:
mrq 2023-03-13 18:51:53 +00:00
parent ee1b048d07
commit 66ac8ba766
3 changed files with 18 additions and 7 deletions

View File

@ -1,7 +1,7 @@
name: '${voice}'
model: extensibletrainer
scale: 1
gpu_ids: [0] # Superfluous, redundant, unnecessary, the way you launch the training script will set this
gpu_ids: [0] # Manually edit this if the GPU you want to train on is not your primary, as this will set the env var that exposes CUDA devices
start_step: 0
checkpointing_enabled: true
fp16: ${half_p}
@ -17,7 +17,7 @@ datasets:
path: ${dataset_path}
fetcher_mode: ['lj']
phase: train
max_wav_length: 255995
max_wav_length: 255995 # ~11.6 seconds
max_text_length: 200
sample_rate: 22050
load_conditioning: True
@ -26,7 +26,7 @@ datasets:
use_bpe_tokenizer: True
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
load_aligned_codes: False
val: # I really do not care about validation right now
val:
name: validation
n_workers: ${workers}
batch_size: ${validation_batch_size}
@ -83,11 +83,11 @@ steps:
losses:
text_ce:
type: direct
weight: ${text_ce_lr_weight}
weight: ${text_lr_weight}
key: loss_text_ce
mel_ce:
type: direct
weight: 1
weight: ${mel_lr_weight}
key: loss_mel_ce
networks:

View File

@ -1212,6 +1212,14 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
result = results[filename]
use_segment = use_segments
# check if unsegmented text exceeds 200 characters
if not use_segment:
if len(result['text']) > 200:
message = f"Text length too long (200 < {len(text)}), using segments: {filename}"
print(message)
messages.append(message)
use_segment = True
# check if unsegmented audio exceeds 11.6s
if not use_segment:
path = f'{indir}/audio/{filename}'
@ -1254,6 +1262,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
print(message)
messages.append(message)
continue
waveform, sample_rate = torchaudio.load(path)
@ -1340,7 +1349,8 @@ def optimize_training_settings( **kwargs ):
def get_device_batch_size( vram ):
DEVICE_BATCH_SIZE_MAP = [
(32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 156:2 = 78
(70, 128), # based on an A100-80G, I can safely get a ratio of 4096:32 = 128
(32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64
(16, 8), # based on an A4000, I can do a ratio of 512:64 = 8:1
(8, 4), # interpolated
(6, 2), # based on my 2060, it only really lets me have a batch ratio of 2:1

View File

@ -446,7 +446,8 @@ def setup_gradio():
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
with gr.Row():
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
with gr.Row():
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())