tab to generate the training YAML

This commit is contained in:
mrq 2023-02-17 03:05:27 +00:00
parent 3a078df95e
commit f8249aa826
5 changed files with 543 additions and 298 deletions

View File

@ -221,6 +221,27 @@ If you want to reuse its generation settings, simply click `Copy Settings`.
To import a voice, click `Import Voice`. Remember to click `Refresh Voice List` in the `Generate` panel afterwards, if it's a new voice.
### Training
This tab will contain a collection of sub-tabs pertaining to training.
#### Configuration
This will generate the YAML necessary to feed into training. For now, you can set:
* `Batch Size`: size of batches for training, more batches = faster training, at the cost of higher VRAM. setting this to 1 will lead to problems
* `Learning Rate`: how large changes to training will be made, lower values = better over the long term, while higher values will fry a model fast. For fine-tuning, the default *should* be fine, but in the future, a learning rate scheduler would be better (have a higher learning rate initially, then step it down over enough steps/epochs)
* `Print Frequency`: how often to print (I assume)
* `Save Frequency`: how often to save checkpoints
* `Training Name`: name to save the configuration as, as well as the training script to create the folder under
* `Dataset Name`: **!**TODO**!**: fill
* `Dataset Path`: path to the input training text file. For LJSpeech-esque datasets, this is to a textfile formatted like:
```
wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
wavs/LJ001-0002.wav|in being comparatively modern.|in being comparatively modern.
```
* `Validation Name`: **!**TODO**!**: fill
* `Validation Path`: path for the validation set, similar to the dataset. I'm not necessarily sure what to really use for this, so explicitly for testing, I just copied the training dataset text
### Settings
This tab (should) hold a bunch of other settings, from tunables that shouldn't be tampered with, to settings pertaining to the web UI itself.

35
src/train.py Executable file
View File

@ -0,0 +1,35 @@
import torch
import argparse
from ..dlas.codes import *
from ..dlas.codes.utils import util, options as option
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = Trainer()
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(args.opt, opt, args.launcher)
trainer.do_training()

View File

@ -138,7 +138,7 @@ def generate(
try:
tts
except NameError:
raise gr.Error("TTS is still initializing...")
raise Exception("TTS is still initializing...")
if voice != "microphone":
voices = [voice]
@ -147,7 +147,7 @@ def generate(
if voice == "microphone":
if mic_audio is None:
raise gr.Error("Please provide audio from mic when choosing `microphone` as a voice input")
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
mic = load_audio(mic_audio, tts.input_sample_rate)
voice_samples, conditioning_latents = [mic], None
elif voice == "random":
@ -431,4 +431,250 @@ def setup_tortoise(restart=False):
print("Initializating TorToiSe...")
tts = TextToSpeech(minor_optimizations=not args.low_vram)
print("TorToiSe initialized, ready for generation.")
return tts
return tts
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
settings = {
"batch_size": batch_size if batch_size else 128,
"learning_rate": learning_rate if learning_rate else 1e-5,
"print_rate": print_rate if print_rate else 50,
"save_rate": save_rate if save_rate else 50,
"name": name if name else "finetune",
"dataset_name": dataset_name if dataset_name else "finetune",
"dataset_path": dataset_path if dataset_path else "./experiments/finetune/train.txt",
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./experiments/finetune/val.txt",
}
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read()
for k in settings:
print(f"${{{k}}} => {settings[k]}")
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
f.write(yaml)
def reset_generation_settings():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') )
return import_generate_settings()
def import_voice(file, saveAs = None):
global args
j, latents = read_generate_settings(file, read_latents=True)
if j is not None and saveAs is None:
saveAs = j['voice']
if saveAs is None or saveAs == "":
raise Exception("Specify a voice name")
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
if latents:
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
print(f"Imported latents to {latents}")
else:
filename = file.name
if filename[-4:] != ".wav":
raise Exception("Please convert to a WAV first")
path = f"{outdir}/{os.path.basename(filename)}"
waveform, sampling_rate = torchaudio.load(filename)
if args.voice_fixer:
# resample to best bandwidth since voicefixer will do it anyways through librosa
if sampling_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sampling_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sampling_rate = 44100
torchaudio.save(path, waveform, sampling_rate)
print(f"Running 'voicefixer' on voice sample: {path}")
voicefixer.restore(
input = path,
output = path,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
else:
torchaudio.save(path, waveform, sampling_rate)
print(f"Imported voice to {path}")
def import_generate_settings(file="./config/generate.json"):
settings, _ = read_generate_settings(file, read_latents=False)
if settings is None:
return None
return (
None if 'text' not in settings else settings['text'],
None if 'delimiter' not in settings else settings['delimiter'],
None if 'emotion' not in settings else settings['emotion'],
None if 'prompt' not in settings else settings['prompt'],
None if 'voice' not in settings else settings['voice'],
None,
None,
None if 'seed' not in settings else settings['seed'],
None if 'candidates' not in settings else settings['candidates'],
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
0.8 if 'temperature' not in settings else settings['temperature'],
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
8 if 'breathing_room' not in settings else settings['breathing_room'],
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
0.8 if 'top_p' not in settings else settings['top_p'],
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
None if 'experimentals' not in settings else settings['experimentals'],
)
def curl(url):
try:
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
conn = urllib.request.urlopen(req)
data = conn.read()
data = data.decode()
data = json.loads(data)
conn.close()
return data
except Exception as e:
print(e)
return None
def check_for_updates():
if not os.path.isfile('./.git/FETCH_HEAD'):
print("Cannot check for updates: not from a git repo")
return False
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
head = f.read()
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
if match is None or len(match) == 0:
print("Cannot check for updates: cannot parse FETCH_HEAD")
return False
match = match[0]
local = match[0]
host = match[1]
owner = match[2]
repo = match[3]
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
if res is None or len(res) == 0:
print("Cannot check for updates: cannot fetch from remote")
return False
remote = res[0]["commit"]["id"]
if remote != local:
print(f"New version found: {local[:8]} => {remote[:8]}")
return True
return False
def reload_tts():
global tts
del tts
tts = setup_tortoise(restart=True)
def cancel_generate():
tortoise.api.STOP_SIGNAL = True
def get_voice_list(dir=get_voice_dir()):
os.makedirs(dir, exist_ok=True)
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
global args
args.listen = listen
args.share = share
args.check_for_updates = check_for_updates
args.models_from_local_only = models_from_local_only
args.low_vram = low_vram
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.device_override = device_override
args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata
args.latents_lean_and_mean = latents_lean_and_mean
args.voice_fixer = voice_fixer
args.voice_fixer_use_cuda = voice_fixer_use_cuda
args.concurrency_count = concurrency_count
args.output_sample_rate = output_sample_rate
args.output_volume = output_volume
settings = {
'listen': None if args.listen else args.listen,
'share': args.share,
'low-vram':args.low_vram,
'check-for-updates':args.check_for_updates,
'models-from-local-only':args.models_from_local_only,
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'device-override': args.device_override,
'sample-batch-size': args.sample_batch_size,
'embed-output-metadata': args.embed_output_metadata,
'latents-lean-and-mean': args.latents_lean_and_mean,
'voice-fixer': args.voice_fixer,
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
'concurrency-count': args.concurrency_count,
'output-sample-rate': args.output_sample_rate,
'output-volume': args.output_volume,
}
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(settings, indent='\t') )
def read_generate_settings(file, read_latents=True, read_json=True):
j = None
latents = None
if file is not None:
if hasattr(file, 'name'):
file = file.name
if file[-4:] == ".wav":
metadata = music_tag.load_file(file)
if 'lyrics' in metadata:
j = json.loads(str(metadata['lyrics']))
elif file[-5:] == ".json":
with open(file, 'r') as f:
j = json.load(f)
if j is None:
print("No metadata found in audio file to read")
else:
if 'latents' in j:
if read_latents:
latents = base64.b64decode(j['latents'])
del j['latents']
if "time" in j:
j["time"] = "{:.3f}".format(j["time"])
return (
j,
latents,
)

View File

@ -21,6 +21,71 @@ from utils import *
args = setup_args()
def run_generation(
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
voice_latents_chunks,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress=gr.Progress(track_tqdm=True)
):
try:
sample, outputs, stats = generate(
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
voice_latents_chunks,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress
)
except Exception as e:
message = str(e)
if message == "Kill signal detected":
reload_tts()
raise gr.Error(message)
return (
outputs[0],
gr.update(value=sample, visible=sample is not None),
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
gr.update(visible=len(outputs) > 1),
gr.update(value=stats, visible=True),
)
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
global args
@ -58,230 +123,6 @@ def update_presets(value):
else:
return (gr.update(), gr.update())
def read_generate_settings(file, read_latents=True, read_json=True):
j = None
latents = None
if file is not None:
if hasattr(file, 'name'):
file = file.name
if file[-4:] == ".wav":
metadata = music_tag.load_file(file)
if 'lyrics' in metadata:
j = json.loads(str(metadata['lyrics']))
elif file[-5:] == ".json":
with open(file, 'r') as f:
j = json.load(f)
if j is None:
gr.Error("No metadata found in audio file to read")
else:
if 'latents' in j:
if read_latents:
latents = base64.b64decode(j['latents'])
del j['latents']
if "time" in j:
j["time"] = "{:.3f}".format(j["time"])
return (
j,
latents,
)
def import_voice(file, saveAs = None):
global args
j, latents = read_generate_settings(file, read_latents=True)
if j is not None and saveAs is None:
saveAs = j['voice']
if saveAs is None or saveAs == "":
raise gr.Error("Specify a voice name")
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
if latents:
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
print(f"Imported latents to {latents}")
else:
filename = file.name
if filename[-4:] != ".wav":
raise gr.Error("Please convert to a WAV first")
path = f"{outdir}/{os.path.basename(filename)}"
waveform, sampling_rate = torchaudio.load(filename)
if args.voice_fixer:
# resample to best bandwidth since voicefixer will do it anyways through librosa
if sampling_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sampling_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sampling_rate = 44100
torchaudio.save(path, waveform, sampling_rate)
print(f"Running 'voicefixer' on voice sample: {path}")
voicefixer.restore(
input = path,
output = path,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
else:
torchaudio.save(path, waveform, sampling_rate)
print(f"Imported voice to {path}")
def import_generate_settings(file="./config/generate.json"):
settings, _ = read_generate_settings(file, read_latents=False)
if settings is None:
return None
return (
None if 'text' not in settings else settings['text'],
None if 'delimiter' not in settings else settings['delimiter'],
None if 'emotion' not in settings else settings['emotion'],
None if 'prompt' not in settings else settings['prompt'],
None if 'voice' not in settings else settings['voice'],
None,
None,
None if 'seed' not in settings else settings['seed'],
None if 'candidates' not in settings else settings['candidates'],
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
0.8 if 'temperature' not in settings else settings['temperature'],
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
8 if 'breathing_room' not in settings else settings['breathing_room'],
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
0.8 if 'top_p' not in settings else settings['top_p'],
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
None if 'experimentals' not in settings else settings['experimentals'],
)
def curl(url):
try:
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
conn = urllib.request.urlopen(req)
data = conn.read()
data = data.decode()
data = json.loads(data)
conn.close()
return data
except Exception as e:
print(e)
return None
def check_for_updates():
if not os.path.isfile('./.git/FETCH_HEAD'):
print("Cannot check for updates: not from a git repo")
return False
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
head = f.read()
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
if match is None or len(match) == 0:
print("Cannot check for updates: cannot parse FETCH_HEAD")
return False
match = match[0]
local = match[0]
host = match[1]
owner = match[2]
repo = match[3]
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
if res is None or len(res) == 0:
print("Cannot check for updates: cannot fetch from remote")
return False
remote = res[0]["commit"]["id"]
if remote != local:
print(f"New version found: {local[:8]} => {remote[:8]}")
return True
return False
def reload_tts():
global tts
del tts
tts = setup_tortoise(restart=True)
def cancel_generate():
tortoise.api.STOP_SIGNAL = True
def get_voice_list(dir=get_voice_dir()):
os.makedirs(dir, exist_ok=True)
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
def update_voices():
return (
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list("./results/")),
)
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
global args
args.listen = listen
args.share = share
args.check_for_updates = check_for_updates
args.models_from_local_only = models_from_local_only
args.low_vram = low_vram
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.device_override = device_override
args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata
args.latents_lean_and_mean = latents_lean_and_mean
args.voice_fixer = voice_fixer
args.voice_fixer_use_cuda = voice_fixer_use_cuda
args.concurrency_count = concurrency_count
args.output_sample_rate = output_sample_rate
args.output_volume = output_volume
settings = {
'listen': None if args.listen else args.listen,
'share': args.share,
'low-vram':args.low_vram,
'check-for-updates':args.check_for_updates,
'models-from-local-only':args.models_from_local_only,
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'device-override': args.device_override,
'sample-batch-size': args.sample_batch_size,
'embed-output-metadata': args.embed_output_metadata,
'latents-lean-and-mean': args.latents_lean_and_mean,
'voice-fixer': args.voice_fixer,
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
'concurrency-count': args.concurrency_count,
'output-sample-rate': args.output_sample_rate,
'output-volume': args.output_volume,
}
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(settings, indent='\t') )
def setup_gradio():
global args
global ui
@ -528,6 +369,29 @@ def setup_gradio():
import_voice_name,
]
)
with gr.Tab("Training"):
with gr.Tab("Configuration"):
with gr.Row():
with gr.Column():
training_settings = [
gr.Slider(label="Batch Size", value=128),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50),
]
with gr.Column():
training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"),
gr.Textbox(label="Dataset Path", placeholder="./experiments/finetune/train.txt"),
gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./experiments/finetune/val.txt"),
]
save_yaml_button = gr.Button(value="Save Training Configuration")
save_yaml_button.click(save_training_settings,
inputs=training_settings,
outputs=None
)
with gr.Tab("Settings"):
with gr.Row():
exec_inputs = []
@ -586,71 +450,15 @@ def setup_gradio():
]
# YUCK
def run_generation(
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
voice_latents_chunks,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress=gr.Progress(track_tqdm=True)
):
try:
sample, outputs, stats = generate(
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
voice_latents_chunks,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress
)
except Exception as e:
message = str(e)
if message == "Kill signal detected":
reload_tts()
raise gr.Error(message)
def update_voices():
return (
outputs[0],
gr.update(value=sample, visible=sample is not None),
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
gr.update(visible=len(outputs) > 1),
gr.update(value=stats, visible=True),
gr.Dropdown.update(choices=get_voice_list()),
gr.Dropdown.update(choices=get_voice_list("./results/")),
)
def history_copy_settings( voice, file ):
return import_generate_settings( f"./results/{voice}/{file}" )
refresh_voices.click(update_voices,
inputs=None,
outputs=[
@ -681,21 +489,12 @@ def setup_gradio():
outputs=input_settings
)
def reset_generation_settings():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') )
return import_generate_settings()
reset_generation_settings_button.click(
fn=reset_generation_settings,
inputs=None,
outputs=input_settings
)
def history_copy_settings( voice, file ):
settings = import_generate_settings( f"./results/{voice}/{file}" )
return settings
history_copy_settings_button.click(history_copy_settings,
inputs=[
history_voices,

144
training/.template.yaml Executable file
View File

@ -0,0 +1,144 @@
name: ${name}
model: extensibletrainer
scale: 1
gpu_ids: [0] # <-- unless you have multiple gpus, use this
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
fp16: false # might want to check this out
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
use_tb_logger: true
datasets:
train:
name: ${dataset_name}
n_workers: 8 # idk what this does
batch_size: ${batch_size} # This leads to ~16GB of vram usage on my 3090.
mode: paired_voice_audio
path: ${dataset_path}
fetcher_mode: ['lj'] # CHANGEME if your dataset isn't in LJSpeech format
phase: train
max_wav_length: 255995
max_text_length: 200
sample_rate: 22050
load_conditioning: True
num_conditioning_candidates: 2
conditioning_length: 44000
use_bpe_tokenizer: True
load_aligned_codes: False
val:
name: ${validation_name}
n_workers: 1
batch_size: 32 # this could be higher probably
mode: paired_voice_audio
path: ${validation_path}
fetcher_mode: ['lj']
phase: val # might be broken idk
max_wav_length: 255995
max_text_length: 200
sample_rate: 22050
load_conditioning: True
num_conditioning_candidates: 2
conditioning_length: 44000
use_bpe_tokenizer: True
load_aligned_codes: False
steps:
gpt_train:
training: gpt
loss_log_buffer: 500 # no idea what this does
# Generally follows the recipe from the DALLE paper.
optimizer: adamw # this should be adamw_zero if you're using distributed training
optimizer_params:
lr: !!float ${learning_rate} # CHANGEME: this was originally 1e-4; I reduced it to 1e-5 because it's fine-tuning, but **you should experiment with this value**
weight_decay: !!float 1e-2
beta1: 0.9
beta2: 0.96
clip_grad_eps: 4
injectors: # TODO: replace this entire sequence with the GptVoiceLatentInjector
paired_to_mel:
type: torch_mel_spectrogram
mel_norm_file: ./experiments/clips_mel_norms.pth
in: wav
out: paired_mel
paired_cond_to_mel:
type: for_each
subtype: torch_mel_spectrogram
mel_norm_file: ./experiments/clips_mel_norms.pth
in: conditioning
out: paired_conditioning_mel
to_codes:
type: discrete_token
in: paired_mel
out: paired_mel_codes
dvae_config: "./experiments/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
paired_fwd_text:
type: generator
generator: gpt
in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
out: [loss_text_ce, loss_mel_ce, logits]
losses:
text_ce:
type: direct
weight: .01
key: loss_text_ce
mel_ce:
type: direct
weight: 1
key: loss_mel_ce
networks:
gpt:
type: generator
which_model_G: unified_voice2 # none of the unified_voice*.py files actually match the tortoise inference code... 4 and 3 have "alignment_head" (wtf is that?), 2 lacks the types=1 parameter.
kwargs:
layers: 30 # WAS 8
model_dim: 1024 # WAS 512
heads: 16 # WAS 8
max_text_tokens: 402 # WAS 120
max_mel_tokens: 604 # WAS 250
max_conditioning_inputs: 2 # WAS 1
mel_length_compression: 1024
number_text_tokens: 256 # supposed to be 255 for newer unified_voice files
number_mel_codes: 8194
start_mel_token: 8192
stop_mel_token: 8193
start_text_token: 255
train_solo_embeddings: False # missing in uv3/4
use_mel_codes_as_input: True # ditto
checkpointing: True
#types: 1 # this is MISSING, but in my analysis 1 is equivalent to not having it.
#only_alignment_head: False # uv3/4
path:
pretrain_model_gpt: './experiments/autoregressive.pth' # CHANGEME: copy this from tortoise cache
strict_load: true
#resume_state: ./experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
niter: 50000
warmup_iter: -1
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
val_freq: 500
default_lr_scheme: MultiStepLR
gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
lr_gamma: 0.5
eval:
output_state: gen
injectors:
gen_inj_eval:
type: generator
generator: generator
in: hq
out: [gen, codebook_commitment_loss]
logger:
print_freq: 100
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
visuals: [gen, mel]
visual_debug_rate: 500
is_mel_spectrogram: true