# Training configuration template. Every `${...}` token is a placeholder that
# is meant to be replaced by text substitution before the file is loaded as YAML.
name: '${voice}'
model: extensibletrainer
scale: 1
# Manually edit this if the GPU you want to train on is not your primary, as
# this will set the env var that exposes CUDA devices.
gpu_ids: [0]
start_step: 0
checkpointing_enabled: true
fp16: ${half_p}
bitsandbytes: ${bitsandbytes}
gpus: ${gpus}

# Dataset definitions. `train` and `val` share the same loader settings and
# differ only in `name`, `phase`, `path` and `batch_size`.
datasets:
  train:
    name: training
    n_workers: ${workers}
    batch_size: ${batch_size}
    mode: paired_voice_audio
    # NOTE(review): unlike `name` at the top of the file, this placeholder is
    # substituted unquoted, so a path containing ` #` or `: ` would not load as
    # the intended string - confirm the generator never emits such paths.
    path: ${dataset_path}
    fetcher_mode: ['lj']
    phase: train
    max_wav_length: 255995  # ~11.6 seconds
    max_text_length: 200
    sample_rate: 22050
    load_conditioning: true
    num_conditioning_candidates: 2
    # Presumably in samples (about 2 seconds at sample_rate 22050) - TODO confirm.
    conditioning_length: 44000
    use_bpe_tokenizer: true
    tokenizer_vocab: ${tokenizer_json}  # ./models/tortoise/bpe_lowercase_asr_256.json
    load_aligned_codes: false
  val:
    name: validation
    n_workers: ${workers}
    batch_size: ${validation_batch_size}
    mode: paired_voice_audio
    # NOTE(review): substituted unquoted, same caveat as `datasets.train.path`.
    path: ${validation_path}
    fetcher_mode: ['lj']
    phase: val
    max_wav_length: 255995  # same limit as the training set
    max_text_length: 200
    sample_rate: 22050
    load_conditioning: true
    num_conditioning_candidates: 2
    conditioning_length: 44000
    use_bpe_tokenizer: true
    tokenizer_vocab: ${tokenizer_json}  # ./models/tortoise/bpe_lowercase_asr_256.json
    load_aligned_codes: false

# Training steps. The single `gpt_train` step trains the network named `gpt`.
steps:
  gpt_train:
    training: gpt
    loss_log_buffer: 500

    # Generally follows the recipe from the DALLE paper.
    optimizer: ${optimizer}  # this should be adamw_zero if you're using distributed training
    optimizer_params:
      # The explicit `!!float` tags force a float: YAML 1.1 loaders such as
      # PyYAML resolve exponent notation without a decimal point (e.g. `1e-2`)
      # as a string otherwise.
      lr: !!float ${learning_rate}  # originally: 1e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.96
    clip_grad_eps: 4

    # Each injector reads the key(s) named by `in` and writes the key(s) named
    # by `out`. The keys chain as:
    #   wav -> paired_mel -> paired_mel_codes -> paired_fwd_text
    #   conditioning -> paired_conditioning_mel -> paired_fwd_text
    injectors:
      paired_to_mel:
        type: torch_mel_spectrogram
        mel_norm_file: ./modules/tortoise-tts/tortoise/data/mel_norms.pth  # ./models/tortoise/clips_mel_norms.pth
        in: wav
        out: paired_mel
      paired_cond_to_mel:
        type: for_each
        subtype: torch_mel_spectrogram
        mel_norm_file: ./modules/tortoise-tts/tortoise/data/mel_norms.pth  # ./models/tortoise/clips_mel_norms.pth
        in: conditioning
        out: paired_conditioning_mel
      to_codes:
        type: discrete_token
        in: paired_mel
        out: paired_mel_codes
        dvae_config: "./models/tortoise/train_diffusion_vocoder_22k_level.yml"
      paired_fwd_text:
        type: generator
        generator: gpt
        in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
        out: [loss_text_ce, loss_mel_ce, logits]
    # Each loss `key` matches one of the `out` entries of `paired_fwd_text`;
    # the weights are filled in from the template placeholders.
    losses:
      text_ce:
        type: direct
        weight: ${text_lr_weight}
        key: loss_text_ce
      mel_ce:
        type: direct
        weight: ${mel_lr_weight}
        key: loss_mel_ce

# Network definitions. `gpt` shares its name with the `training: gpt` and
# `generator: gpt` references in the `steps` section.
networks:
  gpt:
    type: generator
    which_model_G: unified_voice2
    kwargs:
      layers: 30  # originally: 8
      model_dim: 1024  # originally: 512
      heads: 16  # originally: 8
      max_text_tokens: 402  # originally: 120
      max_mel_tokens: 604  # originally: 250
      max_conditioning_inputs: 2  # originally: 1
      mel_length_compression: 1024
      number_text_tokens: 256  # supposed to be 255 for newer unified_voice files
      # The start/stop mel token ids below are the two highest ids that fit in
      # a vocabulary of this size.
      number_mel_codes: 8194
      start_mel_token: 8192
      stop_mel_token: 8193
      start_text_token: 255
      train_solo_embeddings: false  # missing in uv3/4
      use_mel_codes_as_input: true  # also missing in uv3/4
      checkpointing: true
      tortoise_compat: true
      # freeze_everything_but_position_embeddings: True

# Checkpoint loading. The two placeholders below occupy whole lines, so each is
# presumably replaced by a complete `key: value` entry (or a comment) of this
# mapping - verify against the code that renders the template. Their
# indentation must stay at two spaces so the rendered lines remain inside `path`.
path:
  strict_load: true
  ${source_model}
  ${resume_state}

# Training schedule. Iteration count, accumulation factor and validation
# frequency are filled in from the template placeholders.
train:
  niter: ${iterations}
  warmup_iter: -1
  mega_batch_factor: ${gradient_accumulation_size}
  val_freq: ${validation_rate}

  ema_enabled: false # I really don't think EMA matters

  # Whole-line placeholder. NOTE(review): presumably expands to the
  # learning-rate scheduler entries of this mapping; the rendered text has to
  # keep this two-space indentation to stay inside `train` - confirm.
  ${learning_rate_scheme}

# Evaluation settings; `pure` is filled in from the `validation_enabled`
# placeholder.
eval:
  pure: ${validation_enabled}
  output_state: gen

# Logging and checkpointing. Checkpoint saving and visual debug output are both
# driven by the same `${save_rate}` placeholder.
logger:
  save_checkpoint_freq: ${save_rate}
  visuals: [gen, mel]
  visual_debug_rate: ${save_rate}
  is_mel_spectrogram: true