forked from mrq/ai-voice-cloning
cleanup, metrics are grabbed for vall-e trainer
This commit is contained in:
parent
1b72d0bba0
commit
249c6019af
|
@ -83,47 +83,47 @@
|
|||
"x": 37,
|
||||
"y": 38,
|
||||
"z": 39,
|
||||
"d͡": 41,
|
||||
"t͡": 42,
|
||||
"|": 43,
|
||||
"æ": 44,
|
||||
"ð": 45,
|
||||
"ŋ": 46,
|
||||
"ɑ": 47,
|
||||
"ɔ": 48,
|
||||
"ə": 49,
|
||||
"ɚ": 50,
|
||||
"ɛ": 51,
|
||||
"ɡ": 52,
|
||||
"ɪ": 53,
|
||||
"ɹ": 54,
|
||||
"ʃ": 55,
|
||||
"ʊ": 56,
|
||||
"ʌ": 57,
|
||||
"ʒ": 58,
|
||||
"θ": 59,
|
||||
"ɐ": 60,
|
||||
"ɜ": 61,
|
||||
"ᵻ": 62,
|
||||
"ɾ": 63,
|
||||
"n\u0329": 64,
|
||||
"ː": 65,
|
||||
"ˈ": 66,
|
||||
"ˌ": 67,
|
||||
"ʔ": 68,
|
||||
"d͡ʒ": 69,
|
||||
"aɪ": 70,
|
||||
"aʊ": 71,
|
||||
"eɪ": 72,
|
||||
"oʊ": 73,
|
||||
"t͡ʃ": 74,
|
||||
"ɔɪ": 75,
|
||||
"ɔː": 76,
|
||||
"uː": 77,
|
||||
"iː": 78,
|
||||
"ɑː": 79,
|
||||
"oː": 80,
|
||||
"ɜː": 81
|
||||
"d͡": 40,
|
||||
"t͡": 41,
|
||||
"|": 42,
|
||||
"æ": 43,
|
||||
"ð": 44,
|
||||
"ŋ": 45,
|
||||
"ɑ": 46,
|
||||
"ɔ": 47,
|
||||
"ə": 48,
|
||||
"ɚ": 49,
|
||||
"ɛ": 50,
|
||||
"ɡ": 51,
|
||||
"ɪ": 52,
|
||||
"ɹ": 53,
|
||||
"ʃ": 54,
|
||||
"ʊ": 55,
|
||||
"ʌ": 56,
|
||||
"ʒ": 57,
|
||||
"θ": 58,
|
||||
"ɐ": 59,
|
||||
"ɜ": 60,
|
||||
"ᵻ": 61,
|
||||
"ɾ": 62,
|
||||
"n\u0329": 63,
|
||||
"ː": 64,
|
||||
"ˈ": 65,
|
||||
"ˌ": 66,
|
||||
"ʔ": 67,
|
||||
"d͡ʒ": 68,
|
||||
"aɪ": 69,
|
||||
"aʊ": 70,
|
||||
"eɪ": 71,
|
||||
"oʊ": 72,
|
||||
"t͡ʃ": 73,
|
||||
"ɔɪ": 74,
|
||||
"ɔː": 75,
|
||||
"uː": 76,
|
||||
"iː": 77,
|
||||
"ɑː": 78,
|
||||
"oː": 79,
|
||||
"ɜː": 80
|
||||
},
|
||||
"merges": [
|
||||
"a ɪ",
|
||||
|
|
194
src/utils.py
194
src/utils.py
|
@ -10,6 +10,7 @@ if 'TRANSFORMERS_CACHE' not in os.environ:
|
|||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import json
|
||||
import base64
|
||||
import re
|
||||
|
@ -42,11 +43,6 @@ from whisper.normalizers.english import EnglishTextNormalizer
|
|||
from whisper.normalizers.basic import BasicTextNormalizer
|
||||
from whisper.tokenizer import LANGUAGES
|
||||
|
||||
try:
|
||||
from phonemizer import phonemize as phonemizer
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||
|
||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||
|
@ -340,6 +336,9 @@ def generate(**kwargs):
|
|||
|
||||
INFERENCING = True
|
||||
for line, cut_text in enumerate(texts):
|
||||
if should_phonemize():
|
||||
cut_text = phonemizer( cut_text )
|
||||
|
||||
if parameters['emotion'] == "Custom":
|
||||
if parameters['prompt'] and parameters['prompt'].strip() != "":
|
||||
cut_text = f"[{parameters['prompt']},] {cut_text}"
|
||||
|
@ -636,46 +635,31 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
|||
# superfluous, but it cleans up some things
|
||||
class TrainingState():
|
||||
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
|
||||
# parse config to get its iteration
|
||||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
||||
|
||||
self.killed = False
|
||||
|
||||
self.it = 0
|
||||
self.step = 0
|
||||
self.training_dir = os.path.dirname(config_path)
|
||||
with open(config_path, 'r') as file:
|
||||
self.yaml_config = yaml.safe_load(file)
|
||||
|
||||
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
|
||||
self.dataset_dir = f"{self.training_dir}/finetune/"
|
||||
self.dataset_path = f"{self.training_dir}/train.txt"
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
||||
self.batch_size = self.json_config["batch_size"]
|
||||
self.save_rate = self.json_config["save_rate"]
|
||||
|
||||
self.epoch = 0
|
||||
self.epochs = self.json_config["epochs"]
|
||||
self.it = 0
|
||||
self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
|
||||
self.step = 0
|
||||
self.steps = int(self.its / self.dataset_size)
|
||||
self.checkpoint = 0
|
||||
self.checkpoints = int((self.its - self.it) / self.save_rate)
|
||||
|
||||
if args.tts_backend == "tortoise":
|
||||
gpus = self.config["gpus"]
|
||||
|
||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
||||
self.its = self.config['train']['niter']
|
||||
self.steps = 1
|
||||
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||
elif args.tts_backend == "vall-e":
|
||||
self.batch_size = self.config['batch_size']
|
||||
self.dataset_dir = f".{self.config['data_root']}/finetune/"
|
||||
self.dataset_path = f"{self.config['data_root']}/train.txt"
|
||||
|
||||
self.its = 1
|
||||
self.steps = 1
|
||||
self.epochs = 1
|
||||
self.checkpoints = 1
|
||||
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
||||
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
||||
gpus = self.json_config['gpus']
|
||||
self.gpus = self.json_config['gpus']
|
||||
|
||||
self.buffer = []
|
||||
|
||||
|
@ -706,12 +690,15 @@ class TrainingState():
|
|||
'loss': "",
|
||||
}
|
||||
|
||||
self.buffer_json = None
|
||||
self.json_buffer = []
|
||||
|
||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
if keep_x_past_checkpoints > 0:
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
if start:
|
||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||
self.spawn_process(config_path=config_path, gpus=self.gpus)
|
||||
|
||||
def spawn_process(self, config_path, gpus=1):
|
||||
if args.tts_backend == "vall-e":
|
||||
|
@ -771,6 +758,7 @@ class TrainingState():
|
|||
if 'lr' in self.info:
|
||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||
|
||||
if args.tts_backend == "tortoise":
|
||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
@ -779,6 +767,10 @@ class TrainingState():
|
|||
self.losses.append( self.statistics['loss'][-1] )
|
||||
else:
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
else:
|
||||
k = "loss"
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
return data
|
||||
|
||||
|
@ -916,18 +908,62 @@ class TrainingState():
|
|||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
def parse_valle_metrics(self, data):
|
||||
res = {}
|
||||
res['mode'] = "training"
|
||||
res['loss'] = data['model.loss']
|
||||
res['lr'] = data['model.lr']
|
||||
res['it'] = data['global_step']
|
||||
res['step'] = res['it'] % self.dataset_size
|
||||
res['steps'] = self.steps
|
||||
res['epoch'] = int(res['it'] / self.dataset_size)
|
||||
res['iteration_rate'] = data['elapsed_time']
|
||||
return res
|
||||
|
||||
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
should_return = False
|
||||
data = None
|
||||
percent = 0
|
||||
message = None
|
||||
should_return = False
|
||||
|
||||
if line.find('Finished training') >= 0:
|
||||
MESSAGE_START = 'Start training from epoch'
|
||||
MESSAGE_FINSIHED = 'Finished training'
|
||||
MESSAGE_SAVING = 'INFO: Saving models and training states.'
|
||||
|
||||
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
|
||||
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
|
||||
|
||||
if args.tts_backend == "vall-e":
|
||||
|
||||
if self.buffer_json:
|
||||
self.json_buffer.append(line)
|
||||
|
||||
if line.find("{") == 0 and not self.buffer_json:
|
||||
self.buffer_json = True
|
||||
self.json_buffer = [line]
|
||||
if line.find("}") == 0 and self.buffer_json:
|
||||
try:
|
||||
data = json.loads("\n".join(self.json_buffer))
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
if data and 'model.loss' in data:
|
||||
self.training_started = True
|
||||
data = self.parse_valle_metrics( data )
|
||||
print("Training JSON:", data)
|
||||
else:
|
||||
data = None
|
||||
|
||||
self.buffer_json = None
|
||||
self.json_buffer = []
|
||||
|
||||
if line.find(MESSAGE_FINSIHED) >= 0:
|
||||
self.killed = True
|
||||
# rip out iteration info
|
||||
elif not self.training_started:
|
||||
if line.find('Start training from epoch') >= 0:
|
||||
if line.find(MESSAGE_START) >= 0:
|
||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||
|
||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||
|
@ -937,24 +973,23 @@ class TrainingState():
|
|||
if match and len(match) > 0:
|
||||
self.it = int(match[0].replace(",", ""))
|
||||
|
||||
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
||||
self.checkpoints = int((self.its - self.it) / self.save_rate)
|
||||
|
||||
self.load_statistics()
|
||||
|
||||
should_return = True
|
||||
else:
|
||||
data = None
|
||||
if line.find('INFO: Saving models and training states.') >= 0:
|
||||
if line.find(MESSAGE_SAVING) >= 0:
|
||||
self.checkpoint += 1
|
||||
message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
|
||||
percent = self.checkpoint / self.checkpoints
|
||||
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
elif line.find('INFO: Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
|
||||
data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
|
||||
data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
|
||||
data['mode'] = "validation"
|
||||
|
||||
if data is not None:
|
||||
|
@ -1278,7 +1313,7 @@ def phonemize_txt_file( path ):
|
|||
audio = split[0]
|
||||
text = split[2]
|
||||
|
||||
phonemes = phonemizer( text, preserve_punctuation=True, strip=True )
|
||||
phonemes = phonemizer( text )
|
||||
reparsed.append(f'{audio}|{phonemes}')
|
||||
f.write(f'\n{audio}|{phonemes}')
|
||||
|
||||
|
@ -1321,6 +1356,21 @@ def create_dataset_json( path ):
|
|||
with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(data, indent="\t"))
|
||||
|
||||
def phonemizer( text, language="en-us" ):
|
||||
from phonemizer import phonemize
|
||||
if language == "english":
|
||||
language = "en-us"
|
||||
return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
|
||||
|
||||
def should_phonemize():
|
||||
try:
|
||||
from phonemizer import phonemize
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
return False
|
||||
|
||||
return args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json"
|
||||
|
||||
def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
|
||||
indir = f'./training/{voice}/'
|
||||
infile = f'{indir}/whisper.json'
|
||||
|
@ -1332,7 +1382,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
errored = 0
|
||||
messages = []
|
||||
normalize = True
|
||||
phonemize = args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json"
|
||||
phonemize = should_phonemize()
|
||||
lines = { 'training': [], 'validation': [] }
|
||||
segments = {}
|
||||
|
||||
|
@ -1374,7 +1424,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
if use_segment and not use_segments:
|
||||
exists = True
|
||||
for segment in result['segments']:
|
||||
if os.path.exists(filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")):
|
||||
duration = segment['end'] - segment['start']
|
||||
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||
continue
|
||||
|
||||
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
||||
if os.path.exists(path):
|
||||
continue
|
||||
exists = False
|
||||
break
|
||||
|
@ -1396,6 +1451,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
}
|
||||
else:
|
||||
for segment in result['segments']:
|
||||
duration = segment['end'] - segment['start']
|
||||
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||
continue
|
||||
|
||||
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
|
||||
'text': segment['text'],
|
||||
'language': language,
|
||||
|
@ -1412,7 +1471,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
normalizer = result['normalizer']
|
||||
phonemes = result['phonemes']
|
||||
if phonemize and phonemes is None:
|
||||
phonemes = phonemizer( text, language=language if language != "english" else "en-us", strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
|
||||
phonemes = phonemizer( text, language=language )
|
||||
if phonemize:
|
||||
text = phonemes
|
||||
|
||||
|
@ -1456,7 +1515,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
print("Quantized:", file)
|
||||
|
||||
tokens = tokenize_text(text, stringed=False, skip_specials=True)
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join( tokens ).replace(" \u02C8", "\u02C8"))
|
||||
tokenized = " ".join( tokens )
|
||||
tokenized = tokenized.replace(" \u02C8", "\u02C8")
|
||||
tokenized = tokenized.replace(" \u02CC", "\u02CC")
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
|
||||
|
||||
training_joined = "\n".join(lines['training'])
|
||||
validation_joined = "\n".join(lines['validation'])
|
||||
|
@ -1471,8 +1533,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
return "\n".join(messages)
|
||||
|
||||
def calc_iterations( epochs, lines, batch_size ):
|
||||
iterations = int(epochs * lines / float(batch_size))
|
||||
return iterations
|
||||
return int(math.ceil(epochs * math.ceil(lines / batch_size)))
|
||||
|
||||
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
|
||||
return [int(iterations * d) for d in schedule]
|
||||
|
@ -1580,7 +1641,9 @@ def optimize_training_settings( **kwargs ):
|
|||
if not os.path.exists(get_halfp_model_path()):
|
||||
convert_to_halfp()
|
||||
|
||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({int(iterations / settings['epochs'])} steps per epoch)")
|
||||
settings['steps'] = int(iterations / settings['epochs'])
|
||||
|
||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({settings['steps']}) steps per epoch)")
|
||||
|
||||
return settings, messages
|
||||
|
||||
|
@ -1589,6 +1652,7 @@ def save_training_settings( **kwargs ):
|
|||
settings = {}
|
||||
settings.update(kwargs)
|
||||
|
||||
|
||||
outjson = f'./training/{settings["voice"]}/train.json'
|
||||
with open(outjson, 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(settings, indent='\t') )
|
||||
|
@ -1599,6 +1663,8 @@ def save_training_settings( **kwargs ):
|
|||
with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
|
||||
lines = len(f.readlines())
|
||||
|
||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
|
||||
if not settings['source_model'] or settings['source_model'] == "auto":
|
||||
settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"
|
||||
|
||||
|
@ -1606,7 +1672,6 @@ def save_training_settings( **kwargs ):
|
|||
if not os.path.exists(get_halfp_model_path()):
|
||||
convert_to_halfp()
|
||||
|
||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
||||
|
||||
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
||||
|
@ -1622,15 +1687,14 @@ def save_training_settings( **kwargs ):
|
|||
if settings['validation_rate'] < 1:
|
||||
settings['validation_rate'] = 1
|
||||
"""
|
||||
|
||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||
|
||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
"""
|
||||
if settings['iterations'] % settings['save_rate'] != 0:
|
||||
adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate']
|
||||
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}")
|
||||
settings['iterations'] = adjustment
|
||||
"""
|
||||
|
||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||
if not os.path.exists(settings['validation_path']):
|
||||
settings['validation_enabled'] = False
|
||||
messages.append("Validation not found, disabling validation...")
|
||||
|
@ -1833,7 +1897,7 @@ def tokenize_text( text, stringed=True, skip_specials=False ):
|
|||
tts.tokenizer
|
||||
|
||||
encoded = tokenizer.encode(text)
|
||||
decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=specials).split(" ")
|
||||
decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=skip_specials).split(" ")
|
||||
|
||||
if stringed:
|
||||
return "\n".join([ str(encoded), str(decoded) ])
|
||||
|
|
16
src/webui.py
16
src/webui.py
|
@ -467,12 +467,12 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
||||
with gr.Row():
|
||||
with gr.Row(visible=args.tts_backend=="tortoise"):
|
||||
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
||||
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
|
||||
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(visible=args.tts_backend=="tortoise"):
|
||||
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
||||
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
|
||||
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
|
||||
|
@ -488,22 +488,22 @@ def setup_gradio():
|
|||
)
|
||||
with gr.Row():
|
||||
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
||||
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
|
||||
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0, visible=args.tts_backend=="tortoise")
|
||||
with gr.Row():
|
||||
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
|
||||
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
||||
TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
||||
TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp, visible=args.tts_backend=="tortoise")
|
||||
TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb, visible=args.tts_backend=="tortoise")
|
||||
TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
|
||||
|
||||
with gr.Row():
|
||||
TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0)
|
||||
TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0, visible=args.tts_backend=="tortoise")
|
||||
TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0)
|
||||
|
||||
TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
||||
TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state")
|
||||
TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0], visible=args.tts_backend=="tortoise" )
|
||||
TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state", visible=args.tts_backend=="tortoise")
|
||||
|
||||
TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user