forked from mrq/ai-voice-cloning
added mel LR weight (as I finally understand when to adjust the text), added text validation on dataset creation
This commit is contained in:
parent
ee1b048d07
commit
66ac8ba766
|
@ -1,7 +1,7 @@
|
||||||
name: '${voice}'
|
name: '${voice}'
|
||||||
model: extensibletrainer
|
model: extensibletrainer
|
||||||
scale: 1
|
scale: 1
|
||||||
gpu_ids: [0] # Superfluous, redundant, unnecessary, the way you launch the training script will set this
|
gpu_ids: [0] # Manually edit this if the GPU you want to train on is not your primary, as this will set the env var that exposes CUDA devices
|
||||||
start_step: 0
|
start_step: 0
|
||||||
checkpointing_enabled: true
|
checkpointing_enabled: true
|
||||||
fp16: ${half_p}
|
fp16: ${half_p}
|
||||||
|
@ -17,7 +17,7 @@ datasets:
|
||||||
path: ${dataset_path}
|
path: ${dataset_path}
|
||||||
fetcher_mode: ['lj']
|
fetcher_mode: ['lj']
|
||||||
phase: train
|
phase: train
|
||||||
max_wav_length: 255995
|
max_wav_length: 255995 # ~11.6 seconds
|
||||||
max_text_length: 200
|
max_text_length: 200
|
||||||
sample_rate: 22050
|
sample_rate: 22050
|
||||||
load_conditioning: True
|
load_conditioning: True
|
||||||
|
@ -26,7 +26,7 @@ datasets:
|
||||||
use_bpe_tokenizer: True
|
use_bpe_tokenizer: True
|
||||||
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
|
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
|
||||||
load_aligned_codes: False
|
load_aligned_codes: False
|
||||||
val: # I really do not care about validation right now
|
val:
|
||||||
name: validation
|
name: validation
|
||||||
n_workers: ${workers}
|
n_workers: ${workers}
|
||||||
batch_size: ${validation_batch_size}
|
batch_size: ${validation_batch_size}
|
||||||
|
@ -83,11 +83,11 @@ steps:
|
||||||
losses:
|
losses:
|
||||||
text_ce:
|
text_ce:
|
||||||
type: direct
|
type: direct
|
||||||
weight: ${text_ce_lr_weight}
|
weight: ${text_lr_weight}
|
||||||
key: loss_text_ce
|
key: loss_text_ce
|
||||||
mel_ce:
|
mel_ce:
|
||||||
type: direct
|
type: direct
|
||||||
weight: 1
|
weight: ${mel_lr_weight}
|
||||||
key: loss_mel_ce
|
key: loss_mel_ce
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
|
|
12
src/utils.py
12
src/utils.py
|
@ -1212,6 +1212,14 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
use_segment = use_segments
|
use_segment = use_segments
|
||||||
|
|
||||||
|
# check if unsegmented text exceeds 200 characters
|
||||||
|
if not use_segment:
|
||||||
|
if len(result['text']) > 200:
|
||||||
|
message = f"Text length too long (200 < {len(text)}), using segments: {filename}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
use_segment = True
|
||||||
|
|
||||||
# check if unsegmented audio exceeds 11.6s
|
# check if unsegmented audio exceeds 11.6s
|
||||||
if not use_segment:
|
if not use_segment:
|
||||||
path = f'{indir}/audio/{filename}'
|
path = f'{indir}/audio/{filename}'
|
||||||
|
@ -1254,6 +1262,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
continue
|
||||||
|
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
|
|
||||||
|
@ -1340,7 +1349,8 @@ def optimize_training_settings( **kwargs ):
|
||||||
|
|
||||||
def get_device_batch_size( vram ):
|
def get_device_batch_size( vram ):
|
||||||
DEVICE_BATCH_SIZE_MAP = [
|
DEVICE_BATCH_SIZE_MAP = [
|
||||||
(32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 156:2 = 78
|
(70, 128), # based on an A100-80G, I can safely get a ratio of 4096:32 = 128
|
||||||
|
(32, 64), # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64
|
||||||
(16, 8), # based on an A4000, I can do a ratio of 512:64 = 8:1
|
(16, 8), # based on an A4000, I can do a ratio of 512:64 = 8:1
|
||||||
(8, 4), # interpolated
|
(8, 4), # interpolated
|
||||||
(6, 2), # based on my 2060, it only really lets me have a batch ratio of 2:1
|
(6, 2), # based on my 2060, it only really lets me have a batch ratio of 2:1
|
||||||
|
|
|
@ -446,7 +446,8 @@ def setup_gradio():
|
||||||
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
||||||
TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
|
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
|
||||||
|
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user