cleanup, metrics are grabbed for vall-e trainer

This commit is contained in:
mrq 2023-03-17 05:33:49 +00:00
parent 1b72d0bba0
commit 249c6019af
3 changed files with 198 additions and 134 deletions

View File

@ -83,47 +83,47 @@
"x": 37, "x": 37,
"y": 38, "y": 38,
"z": 39, "z": 39,
"d͡": 41, "d͡": 40,
"t͡": 42, "t͡": 41,
"|": 43, "|": 42,
"æ": 44, "æ": 43,
"ð": 45, "ð": 44,
"ŋ": 46, "ŋ": 45,
"ɑ": 47, "ɑ": 46,
"ɔ": 48, "ɔ": 47,
"ə": 49, "ə": 48,
"ɚ": 50, "ɚ": 49,
"ɛ": 51, "ɛ": 50,
"ɡ": 52, "ɡ": 51,
"ɪ": 53, "ɪ": 52,
"ɹ": 54, "ɹ": 53,
"ʃ": 55, "ʃ": 54,
"ʊ": 56, "ʊ": 55,
"ʌ": 57, "ʌ": 56,
"ʒ": 58, "ʒ": 57,
"θ": 59, "θ": 58,
"ɐ": 60, "ɐ": 59,
"ɜ": 61, "ɜ": 60,
"ᵻ": 62, "ᵻ": 61,
"ɾ": 63, "ɾ": 62,
"n\u0329": 64, "n\u0329": 63,
"ː": 65, "ː": 64,
"ˈ": 66, "ˈ": 65,
"ˌ": 67, "ˌ": 66,
"ʔ": 68, "ʔ": 67,
"d͡ʒ": 69, "d͡ʒ": 68,
"aɪ": 70, "aɪ": 69,
"aʊ": 71, "aʊ": 70,
"eɪ": 72, "eɪ": 71,
"oʊ": 73, "oʊ": 72,
"t͡ʃ": 74, "t͡ʃ": 73,
"ɔɪ": 75, "ɔɪ": 74,
"ɔː": 76, "ɔː": 75,
"uː": 77, "uː": 76,
"iː": 78, "iː": 77,
"ɑː": 79, "ɑː": 78,
"oː": 80, "oː": 79,
"ɜː": 81 "ɜː": 80
}, },
"merges": [ "merges": [
"a ɪ", "a ɪ",

View File

@ -10,6 +10,7 @@ if 'TRANSFORMERS_CACHE' not in os.environ:
import argparse import argparse
import time import time
import math
import json import json
import base64 import base64
import re import re
@ -42,11 +43,6 @@ from whisper.normalizers.english import EnglishTextNormalizer
from whisper.normalizers.basic import BasicTextNormalizer from whisper.normalizers.basic import BasicTextNormalizer
from whisper.tokenizer import LANGUAGES from whisper.tokenizer import LANGUAGES
try:
from phonemizer import phonemize as phonemizer
except Exception as e:
pass
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
@ -340,6 +336,9 @@ def generate(**kwargs):
INFERENCING = True INFERENCING = True
for line, cut_text in enumerate(texts): for line, cut_text in enumerate(texts):
if should_phonemize():
cut_text = phonemizer( cut_text )
if parameters['emotion'] == "Custom": if parameters['emotion'] == "Custom":
if parameters['prompt'] and parameters['prompt'].strip() != "": if parameters['prompt'] and parameters['prompt'].strip() != "":
cut_text = f"[{parameters['prompt']},] {cut_text}" cut_text = f"[{parameters['prompt']},] {cut_text}"
@ -636,46 +635,31 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
# superfluous, but it cleans up some things # superfluous, but it cleans up some things
class TrainingState(): class TrainingState():
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True): def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
# parse config to get its iteration
with open(config_path, 'r') as file:
self.config = yaml.safe_load(file)
self.killed = False self.killed = False
self.it = 0 self.training_dir = os.path.dirname(config_path)
self.step = 0 with open(config_path, 'r') as file:
self.yaml_config = yaml.safe_load(file)
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
self.dataset_dir = f"{self.training_dir}/finetune/"
self.dataset_path = f"{self.training_dir}/train.txt"
with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines())
self.batch_size = self.json_config["batch_size"]
self.save_rate = self.json_config["save_rate"]
self.epoch = 0 self.epoch = 0
self.epochs = self.json_config["epochs"]
self.it = 0
self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
self.step = 0
self.steps = int(self.its / self.dataset_size)
self.checkpoint = 0 self.checkpoint = 0
self.checkpoints = int((self.its - self.it) / self.save_rate)
if args.tts_backend == "tortoise": self.gpus = self.json_config['gpus']
gpus = self.config["gpus"]
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path']
with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines())
self.its = self.config['train']['niter']
self.steps = 1
self.epochs = int(self.its*self.batch_size/self.dataset_size)
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
elif args.tts_backend == "vall-e":
self.batch_size = self.config['batch_size']
self.dataset_dir = f".{self.config['data_root']}/finetune/"
self.dataset_path = f"{self.config['data_root']}/train.txt"
self.its = 1
self.steps = 1
self.epochs = 1
self.checkpoints = 1
with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines())
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
gpus = self.json_config['gpus']
self.buffer = [] self.buffer = []
@ -706,12 +690,15 @@ class TrainingState():
'loss': "", 'loss': "",
} }
self.buffer_json = None
self.json_buffer = []
self.loss_milestones = [ 1.0, 0.15, 0.05 ] self.loss_milestones = [ 1.0, 0.15, 0.05 ]
if keep_x_past_checkpoints > 0: if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_checkpoints) self.cleanup_old(keep=keep_x_past_checkpoints)
if start: if start:
self.spawn_process(config_path=config_path, gpus=gpus) self.spawn_process(config_path=config_path, gpus=self.gpus)
def spawn_process(self, config_path, gpus=1): def spawn_process(self, config_path, gpus=1):
if args.tts_backend == "vall-e": if args.tts_backend == "vall-e":
@ -771,6 +758,7 @@ class TrainingState():
if 'lr' in self.info: if 'lr' in self.info:
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'}) self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
if args.tts_backend == "tortoise":
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']: for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
if k not in self.info: if k not in self.info:
continue continue
@ -779,6 +767,10 @@ class TrainingState():
self.losses.append( self.statistics['loss'][-1] ) self.losses.append( self.statistics['loss'][-1] )
else: else:
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
else:
k = "loss"
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
self.losses.append( self.statistics['loss'][-1] )
return data return data
@ -916,18 +908,62 @@ class TrainingState():
print("Removing", path) print("Removing", path)
os.remove(path) os.remove(path)
def parse_valle_metrics(self, data):
res = {}
res['mode'] = "training"
res['loss'] = data['model.loss']
res['lr'] = data['model.lr']
res['it'] = data['global_step']
res['step'] = res['it'] % self.dataset_size
res['steps'] = self.steps
res['epoch'] = int(res['it'] / self.dataset_size)
res['iteration_rate'] = data['elapsed_time']
return res
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ): def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
self.buffer.append(f'{line}') self.buffer.append(f'{line}')
should_return = False data = None
percent = 0 percent = 0
message = None message = None
should_return = False
if line.find('Finished training') >= 0: MESSAGE_START = 'Start training from epoch'
MESSAGE_FINSIHED = 'Finished training'
MESSAGE_SAVING = 'INFO: Saving models and training states.'
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
if args.tts_backend == "vall-e":
if self.buffer_json:
self.json_buffer.append(line)
if line.find("{") == 0 and not self.buffer_json:
self.buffer_json = True
self.json_buffer = [line]
if line.find("}") == 0 and self.buffer_json:
try:
data = json.loads("\n".join(self.json_buffer))
except Exception as e:
print(str(e))
if data and 'model.loss' in data:
self.training_started = True
data = self.parse_valle_metrics( data )
print("Training JSON:", data)
else:
data = None
self.buffer_json = None
self.json_buffer = []
if line.find(MESSAGE_FINSIHED) >= 0:
self.killed = True self.killed = True
# rip out iteration info # rip out iteration info
elif not self.training_started: elif not self.training_started:
if line.find('Start training from epoch') >= 0: if line.find(MESSAGE_START) >= 0:
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
match = re.findall(r'epoch: ([\d,]+)', line) match = re.findall(r'epoch: ([\d,]+)', line)
@ -937,24 +973,23 @@ class TrainingState():
if match and len(match) > 0: if match and len(match) > 0:
self.it = int(match[0].replace(",", "")) self.it = int(match[0].replace(",", ""))
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) self.checkpoints = int((self.its - self.it) / self.save_rate)
self.load_statistics() self.load_statistics()
should_return = True should_return = True
else: else:
data = None if line.find(MESSAGE_SAVING) >= 0:
if line.find('INFO: Saving models and training states.') >= 0:
self.checkpoint += 1 self.checkpoint += 1
message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..." message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
percent = self.checkpoint / self.checkpoints percent = self.checkpoint / self.checkpoints
self.cleanup_old(keep=keep_x_past_checkpoints) self.cleanup_old(keep=keep_x_past_checkpoints)
elif line.find('INFO: Training Metrics:') >= 0: elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
data = json.loads(line.split("INFO: Training Metrics:")[-1]) data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
data['mode'] = "training" data['mode'] = "training"
elif line.find('INFO: Validation Metrics:') >= 0: elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
data = json.loads(line.split("INFO: Validation Metrics:")[-1]) data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
data['mode'] = "validation" data['mode'] = "validation"
if data is not None: if data is not None:
@ -1278,7 +1313,7 @@ def phonemize_txt_file( path ):
audio = split[0] audio = split[0]
text = split[2] text = split[2]
phonemes = phonemizer( text, preserve_punctuation=True, strip=True ) phonemes = phonemizer( text )
reparsed.append(f'{audio}|{phonemes}') reparsed.append(f'{audio}|{phonemes}')
f.write(f'\n{audio}|{phonemes}') f.write(f'\n{audio}|{phonemes}')
@ -1321,6 +1356,21 @@ def create_dataset_json( path ):
with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f: with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
f.write(json.dumps(data, indent="\t")) f.write(json.dumps(data, indent="\t"))
def phonemizer( text, language="en-us" ):
from phonemizer import phonemize
if language == "english":
language = "en-us"
return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
def should_phonemize():
try:
from phonemizer import phonemize
except Exception as e:
print(str(e))
return False
return args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json"
def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ): def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json' infile = f'{indir}/whisper.json'
@ -1332,7 +1382,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
errored = 0 errored = 0
messages = [] messages = []
normalize = True normalize = True
phonemize = args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json" phonemize = should_phonemize()
lines = { 'training': [], 'validation': [] } lines = { 'training': [], 'validation': [] }
segments = {} segments = {}
@ -1374,7 +1424,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if use_segment and not use_segments: if use_segment and not use_segments:
exists = True exists = True
for segment in result['segments']: for segment in result['segments']:
if os.path.exists(filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")): duration = segment['end'] - segment['start']
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
if os.path.exists(path):
continue continue
exists = False exists = False
break break
@ -1396,6 +1451,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
} }
else: else:
for segment in result['segments']: for segment in result['segments']:
duration = segment['end'] - segment['start']
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = { segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
'text': segment['text'], 'text': segment['text'],
'language': language, 'language': language,
@ -1412,7 +1471,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
normalizer = result['normalizer'] normalizer = result['normalizer']
phonemes = result['phonemes'] phonemes = result['phonemes']
if phonemize and phonemes is None: if phonemize and phonemes is None:
phonemes = phonemizer( text, language=language if language != "english" else "en-us", strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend ) phonemes = phonemizer( text, language=language )
if phonemize: if phonemize:
text = phonemes text = phonemes
@ -1456,7 +1515,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print("Quantized:", file) print("Quantized:", file)
tokens = tokenize_text(text, stringed=False, skip_specials=True) tokens = tokenize_text(text, stringed=False, skip_specials=True)
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join( tokens ).replace(" \u02C8", "\u02C8")) tokenized = " ".join( tokens )
tokenized = tokenized.replace(" \u02C8", "\u02C8")
tokenized = tokenized.replace(" \u02CC", "\u02CC")
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
training_joined = "\n".join(lines['training']) training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation']) validation_joined = "\n".join(lines['validation'])
@ -1471,8 +1533,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
return "\n".join(messages) return "\n".join(messages)
def calc_iterations( epochs, lines, batch_size ): def calc_iterations( epochs, lines, batch_size ):
iterations = int(epochs * lines / float(batch_size)) return int(math.ceil(epochs * math.ceil(lines / batch_size)))
return iterations
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ): def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
return [int(iterations * d) for d in schedule] return [int(iterations * d) for d in schedule]
@ -1580,7 +1641,9 @@ def optimize_training_settings( **kwargs ):
if not os.path.exists(get_halfp_model_path()): if not os.path.exists(get_halfp_model_path()):
convert_to_halfp() convert_to_halfp()
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({int(iterations / settings['epochs'])} steps per epoch)") settings['steps'] = int(iterations / settings['epochs'])
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({settings['steps']}) steps per epoch)")
return settings, messages return settings, messages
@ -1589,6 +1652,7 @@ def save_training_settings( **kwargs ):
settings = {} settings = {}
settings.update(kwargs) settings.update(kwargs)
outjson = f'./training/{settings["voice"]}/train.json' outjson = f'./training/{settings["voice"]}/train.json'
with open(outjson, 'w', encoding="utf-8") as f: with open(outjson, 'w', encoding="utf-8") as f:
f.write(json.dumps(settings, indent='\t') ) f.write(json.dumps(settings, indent='\t') )
@ -1599,6 +1663,8 @@ def save_training_settings( **kwargs ):
with open(settings['dataset_path'], 'r', encoding="utf-8") as f: with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
lines = len(f.readlines()) lines = len(f.readlines())
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if not settings['source_model'] or settings['source_model'] == "auto": if not settings['source_model'] or settings['source_model'] == "auto":
settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth" settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"
@ -1606,7 +1672,6 @@ def save_training_settings( **kwargs ):
if not os.path.exists(get_halfp_model_path()): if not os.path.exists(get_halfp_model_path()):
convert_to_halfp() convert_to_halfp()
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps") messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
iterations_per_epoch = settings['iterations'] / settings['epochs'] iterations_per_epoch = settings['iterations'] / settings['epochs']
@ -1622,15 +1687,14 @@ def save_training_settings( **kwargs ):
if settings['validation_rate'] < 1: if settings['validation_rate'] < 1:
settings['validation_rate'] = 1 settings['validation_rate'] = 1
""" """
"""
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if settings['iterations'] % settings['save_rate'] != 0: if settings['iterations'] % settings['save_rate'] != 0:
adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate'] adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate']
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}") messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}")
settings['iterations'] = adjustment settings['iterations'] = adjustment
"""
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
if not os.path.exists(settings['validation_path']): if not os.path.exists(settings['validation_path']):
settings['validation_enabled'] = False settings['validation_enabled'] = False
messages.append("Validation not found, disabling validation...") messages.append("Validation not found, disabling validation...")
@ -1833,7 +1897,7 @@ def tokenize_text( text, stringed=True, skip_specials=False ):
tts.tokenizer tts.tokenizer
encoded = tokenizer.encode(text) encoded = tokenizer.encode(text)
decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=specials).split(" ") decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=skip_specials).split(" ")
if stringed: if stringed:
return "\n".join([ str(encoded), str(decoded) ]) return "\n".join([ str(encoded), str(decoded) ])

View File

@ -467,12 +467,12 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0) TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
with gr.Row(): with gr.Row(visible=args.tts_backend=="tortoise"):
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6) TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1) TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1) TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
with gr.Row(): with gr.Row(visible=args.tts_backend=="tortoise"):
lr_schemes = list(LEARNING_RATE_SCHEMES.keys()) lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value") TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True) TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
@ -488,22 +488,22 @@ def setup_gradio():
) )
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0) TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0) TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0, visible=args.tts_backend=="tortoise")
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp, visible=args.tts_backend=="tortoise")
TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb, visible=args.tts_backend=="tortoise")
TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False) TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0) TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0, visible=args.tts_backend=="tortoise")
TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0) TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0)
TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] ) TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0], visible=args.tts_backend=="tortoise" )
TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state") TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state", visible=args.tts_backend=="tortoise")
TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" ) TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )