# Forked from mrq/ai-voice-cloning
# 142 lines, 4.2 KiB, YAML, executable file
# Top-level trainer settings.
# `${...}` tokens are template placeholders; presumably they are substituted
# textually by whatever renders this template before it is parsed - confirm
# against the renderer.
name: ${name}
model: extensibletrainer
scale: 1
gpu_ids: [0]  # Superfluous, redundant, unnecessary, the way you launch the training script will set this
start_step: 0
checkpointing_enabled: true
fp16: ${float16}
wandb: false
use_tb_logger: true
# Dataset definitions. `train` and `val` carry the same set of keys; in this
# template they differ only in `name`, `path` and `phase`.
datasets:
  train:
    name: ${dataset_name}
    n_workers: ${workers}
    batch_size: ${batch_size}
    mode: paired_voice_audio
    path: ${dataset_path}
    fetcher_mode: ['lj']
    phase: train
    max_wav_length: 255995
    max_text_length: 200
    sample_rate: 22050
    load_conditioning: true
    num_conditioning_candidates: 2
    conditioning_length: 44000
    use_bpe_tokenizer: true
    tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
    load_aligned_codes: false
  val:  # I really do not care about validation right now
    name: ${validation_name}
    n_workers: ${workers}
    batch_size: ${batch_size}
    mode: paired_voice_audio
    path: ${validation_path}
    fetcher_mode: ['lj']
    phase: val
    max_wav_length: 255995
    max_text_length: 200
    sample_rate: 22050
    load_conditioning: true
    num_conditioning_candidates: 2
    conditioning_length: 44000
    use_bpe_tokenizer: true
    tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
    load_aligned_codes: false
# Training steps. A single step, `gpt_train`, trains the `gpt` network
# declared under `networks`.
steps:
  gpt_train:
    training: gpt
    loss_log_buffer: 500

    # Generally follows the recipe from the DALLE paper.
    optimizer: adamw  # this should be adamw_zero if you're using distributed training
    optimizer_params:
      lr: !!float ${learning_rate}  # originally: 1e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.96
    # NOTE(review): placed at step level rather than inside optimizer_params;
    # the original indentation was lost - confirm against the trainer.
    clip_grad_eps: 4

    # Injectors are keyed by name; each declares a `type`, its input key(s)
    # (`in`) and its output key(s) (`out`). The `out` of one entry is reused
    # as the `in` of later entries (e.g. paired_mel -> to_codes).
    injectors:
      paired_to_mel:
        type: torch_mel_spectrogram
        mel_norm_file: ./models/tortoise/clips_mel_norms.pth
        in: wav
        out: paired_mel
      paired_cond_to_mel:
        type: for_each
        subtype: torch_mel_spectrogram
        mel_norm_file: ./models/tortoise/clips_mel_norms.pth
        in: conditioning
        out: paired_conditioning_mel
      to_codes:
        type: discrete_token
        in: paired_mel
        out: paired_mel_codes
        dvae_config: "./models/tortoise/train_diffusion_vocoder_22k_level.yml"
      paired_fwd_text:
        type: generator
        generator: gpt
        in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
        out: [loss_text_ce, loss_mel_ce, logits]

    # Each loss reads the value stored under `key` (both keys are outputs of
    # the paired_fwd_text injector above) and applies `weight`.
    losses:
      text_ce:
        type: direct
        weight: ${text_ce_lr_weight}
        key: loss_text_ce
      mel_ce:
        type: direct
        weight: 1
        key: loss_mel_ce
# Network definitions. `gpt` is the network referenced by `training: gpt` and
# `generator: gpt` in the `steps` section.
networks:
  gpt:
    type: generator
    # none of the unified_voice*.py files actually match the tortoise inference
    # code... 4 and 3 have "alignment_head" (wtf is that?), 2 lacks the types=1
    # parameter.
    which_model_G: unified_voice2
    kwargs:
      layers: 30  # originally: 8
      model_dim: 1024  # originally: 512
      heads: 16  # originally: 8
      max_text_tokens: 402  # originally: 120
      max_mel_tokens: 604  # originally: 250
      max_conditioning_inputs: 2  # originally: 1
      mel_length_compression: 1024
      number_text_tokens: 256  # supposed to be 255 for newer unified_voice files
      number_mel_codes: 8194
      start_mel_token: 8192
      stop_mel_token: 8193
      start_text_token: 255
      train_solo_embeddings: false  # missing in uv3/4
      use_mel_codes_as_input: true  # ditto
      checkpointing: true
      #types: 1 # this is MISSING, but in my analysis 1 is equivalent to not having it.
      #only_alignment_head: False # uv3/4
# Checkpoint / state paths.
# NOTE(review): `${pretrain_model_gpt}` and `${resume_state}` stand alone on
# their lines, so they presumably expand to complete `key: value` entries (or
# to nothing) when the template is rendered - confirm against the renderer.
# Until they are substituted this section is not a valid YAML mapping.
path:
  ${pretrain_model_gpt}
  strict_load: true
  ${resume_state}
# Training-loop schedule: iteration count, accumulation factor, validation
# frequency and learning-rate scheduler settings.
train:
  niter: ${iterations}
  warmup_iter: -1
  mega_batch_factor: ${gradient_accumulation_size}
  val_freq: ${validation_rate}

  ema_enabled: false  # I really don't think EMA matters

  default_lr_scheme: MultiStepLR
  gen_lr_steps: ${gen_lr_steps}  # [50000, 100000, 140000, 180000]
  lr_gamma: 0.5
# Evaluation settings.
eval:
  pure: true
  output_state: gen
# Logging and checkpoint cadence. `print_freq` and `visual_debug_rate` are
# both filled from the same `${print_rate}` placeholder.
logger:
  print_freq: ${print_rate}
  save_checkpoint_freq: ${save_rate}
  visuals: [gen, mel]
  visual_debug_rate: ${print_rate}
  is_mel_spectrogram: true