forked from mrq/ai-voice-cloning
cleanup, metrics are grabbed for vall-e trainer
This commit is contained in:
parent
1b72d0bba0
commit
249c6019af
|
@ -83,47 +83,47 @@
|
||||||
"x": 37,
|
"x": 37,
|
||||||
"y": 38,
|
"y": 38,
|
||||||
"z": 39,
|
"z": 39,
|
||||||
"d͡": 41,
|
"d͡": 40,
|
||||||
"t͡": 42,
|
"t͡": 41,
|
||||||
"|": 43,
|
"|": 42,
|
||||||
"æ": 44,
|
"æ": 43,
|
||||||
"ð": 45,
|
"ð": 44,
|
||||||
"ŋ": 46,
|
"ŋ": 45,
|
||||||
"ɑ": 47,
|
"ɑ": 46,
|
||||||
"ɔ": 48,
|
"ɔ": 47,
|
||||||
"ə": 49,
|
"ə": 48,
|
||||||
"ɚ": 50,
|
"ɚ": 49,
|
||||||
"ɛ": 51,
|
"ɛ": 50,
|
||||||
"ɡ": 52,
|
"ɡ": 51,
|
||||||
"ɪ": 53,
|
"ɪ": 52,
|
||||||
"ɹ": 54,
|
"ɹ": 53,
|
||||||
"ʃ": 55,
|
"ʃ": 54,
|
||||||
"ʊ": 56,
|
"ʊ": 55,
|
||||||
"ʌ": 57,
|
"ʌ": 56,
|
||||||
"ʒ": 58,
|
"ʒ": 57,
|
||||||
"θ": 59,
|
"θ": 58,
|
||||||
"ɐ": 60,
|
"ɐ": 59,
|
||||||
"ɜ": 61,
|
"ɜ": 60,
|
||||||
"ᵻ": 62,
|
"ᵻ": 61,
|
||||||
"ɾ": 63,
|
"ɾ": 62,
|
||||||
"n\u0329": 64,
|
"n\u0329": 63,
|
||||||
"ː": 65,
|
"ː": 64,
|
||||||
"ˈ": 66,
|
"ˈ": 65,
|
||||||
"ˌ": 67,
|
"ˌ": 66,
|
||||||
"ʔ": 68,
|
"ʔ": 67,
|
||||||
"d͡ʒ": 69,
|
"d͡ʒ": 68,
|
||||||
"aɪ": 70,
|
"aɪ": 69,
|
||||||
"aʊ": 71,
|
"aʊ": 70,
|
||||||
"eɪ": 72,
|
"eɪ": 71,
|
||||||
"oʊ": 73,
|
"oʊ": 72,
|
||||||
"t͡ʃ": 74,
|
"t͡ʃ": 73,
|
||||||
"ɔɪ": 75,
|
"ɔɪ": 74,
|
||||||
"ɔː": 76,
|
"ɔː": 75,
|
||||||
"uː": 77,
|
"uː": 76,
|
||||||
"iː": 78,
|
"iː": 77,
|
||||||
"ɑː": 79,
|
"ɑː": 78,
|
||||||
"oː": 80,
|
"oː": 79,
|
||||||
"ɜː": 81
|
"ɜː": 80
|
||||||
},
|
},
|
||||||
"merges": [
|
"merges": [
|
||||||
"a ɪ",
|
"a ɪ",
|
||||||
|
|
194
src/utils.py
194
src/utils.py
|
@ -10,6 +10,7 @@ if 'TRANSFORMERS_CACHE' not in os.environ:
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
@ -42,11 +43,6 @@ from whisper.normalizers.english import EnglishTextNormalizer
|
||||||
from whisper.normalizers.basic import BasicTextNormalizer
|
from whisper.normalizers.basic import BasicTextNormalizer
|
||||||
from whisper.tokenizer import LANGUAGES
|
from whisper.tokenizer import LANGUAGES
|
||||||
|
|
||||||
try:
|
|
||||||
from phonemizer import phonemize as phonemizer
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
|
||||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||||
|
@ -340,6 +336,9 @@ def generate(**kwargs):
|
||||||
|
|
||||||
INFERENCING = True
|
INFERENCING = True
|
||||||
for line, cut_text in enumerate(texts):
|
for line, cut_text in enumerate(texts):
|
||||||
|
if should_phonemize():
|
||||||
|
cut_text = phonemizer( cut_text )
|
||||||
|
|
||||||
if parameters['emotion'] == "Custom":
|
if parameters['emotion'] == "Custom":
|
||||||
if parameters['prompt'] and parameters['prompt'].strip() != "":
|
if parameters['prompt'] and parameters['prompt'].strip() != "":
|
||||||
cut_text = f"[{parameters['prompt']},] {cut_text}"
|
cut_text = f"[{parameters['prompt']},] {cut_text}"
|
||||||
|
@ -636,46 +635,31 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
||||||
# superfluous, but it cleans up some things
|
# superfluous, but it cleans up some things
|
||||||
class TrainingState():
|
class TrainingState():
|
||||||
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
|
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
|
||||||
# parse config to get its iteration
|
|
||||||
with open(config_path, 'r') as file:
|
|
||||||
self.config = yaml.safe_load(file)
|
|
||||||
|
|
||||||
|
|
||||||
self.killed = False
|
self.killed = False
|
||||||
|
|
||||||
self.it = 0
|
self.training_dir = os.path.dirname(config_path)
|
||||||
self.step = 0
|
with open(config_path, 'r') as file:
|
||||||
|
self.yaml_config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
|
||||||
|
self.dataset_dir = f"{self.training_dir}/finetune/"
|
||||||
|
self.dataset_path = f"{self.training_dir}/train.txt"
|
||||||
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
self.dataset_size = len(f.readlines())
|
||||||
|
|
||||||
|
self.batch_size = self.json_config["batch_size"]
|
||||||
|
self.save_rate = self.json_config["save_rate"]
|
||||||
|
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
|
self.epochs = self.json_config["epochs"]
|
||||||
|
self.it = 0
|
||||||
|
self.its = calc_iterations( self.epochs, self.dataset_size, self.batch_size )
|
||||||
|
self.step = 0
|
||||||
|
self.steps = int(self.its / self.dataset_size)
|
||||||
self.checkpoint = 0
|
self.checkpoint = 0
|
||||||
|
self.checkpoints = int((self.its - self.it) / self.save_rate)
|
||||||
|
|
||||||
if args.tts_backend == "tortoise":
|
self.gpus = self.json_config['gpus']
|
||||||
gpus = self.config["gpus"]
|
|
||||||
|
|
||||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
|
||||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
|
||||||
self.dataset_path = self.config['datasets']['train']['path']
|
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
|
||||||
self.dataset_size = len(f.readlines())
|
|
||||||
|
|
||||||
self.its = self.config['train']['niter']
|
|
||||||
self.steps = 1
|
|
||||||
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
|
||||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
|
||||||
elif args.tts_backend == "vall-e":
|
|
||||||
self.batch_size = self.config['batch_size']
|
|
||||||
self.dataset_dir = f".{self.config['data_root']}/finetune/"
|
|
||||||
self.dataset_path = f"{self.config['data_root']}/train.txt"
|
|
||||||
|
|
||||||
self.its = 1
|
|
||||||
self.steps = 1
|
|
||||||
self.epochs = 1
|
|
||||||
self.checkpoints = 1
|
|
||||||
|
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
|
||||||
self.dataset_size = len(f.readlines())
|
|
||||||
|
|
||||||
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
|
||||||
gpus = self.json_config['gpus']
|
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
|
|
||||||
|
@ -706,12 +690,15 @@ class TrainingState():
|
||||||
'loss': "",
|
'loss': "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.buffer_json = None
|
||||||
|
self.json_buffer = []
|
||||||
|
|
||||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||||
|
|
||||||
if keep_x_past_checkpoints > 0:
|
if keep_x_past_checkpoints > 0:
|
||||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||||
if start:
|
if start:
|
||||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
self.spawn_process(config_path=config_path, gpus=self.gpus)
|
||||||
|
|
||||||
def spawn_process(self, config_path, gpus=1):
|
def spawn_process(self, config_path, gpus=1):
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
|
@ -771,6 +758,7 @@ class TrainingState():
|
||||||
if 'lr' in self.info:
|
if 'lr' in self.info:
|
||||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||||
|
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
continue
|
continue
|
||||||
|
@ -779,6 +767,10 @@ class TrainingState():
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
else:
|
else:
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||||
|
else:
|
||||||
|
k = "loss"
|
||||||
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||||
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -916,18 +908,62 @@ class TrainingState():
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
|
def parse_valle_metrics(self, data):
|
||||||
|
res = {}
|
||||||
|
res['mode'] = "training"
|
||||||
|
res['loss'] = data['model.loss']
|
||||||
|
res['lr'] = data['model.lr']
|
||||||
|
res['it'] = data['global_step']
|
||||||
|
res['step'] = res['it'] % self.dataset_size
|
||||||
|
res['steps'] = self.steps
|
||||||
|
res['epoch'] = int(res['it'] / self.dataset_size)
|
||||||
|
res['iteration_rate'] = data['elapsed_time']
|
||||||
|
return res
|
||||||
|
|
||||||
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
should_return = False
|
data = None
|
||||||
percent = 0
|
percent = 0
|
||||||
message = None
|
message = None
|
||||||
|
should_return = False
|
||||||
|
|
||||||
if line.find('Finished training') >= 0:
|
MESSAGE_START = 'Start training from epoch'
|
||||||
|
MESSAGE_FINSIHED = 'Finished training'
|
||||||
|
MESSAGE_SAVING = 'INFO: Saving models and training states.'
|
||||||
|
|
||||||
|
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
|
||||||
|
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
|
||||||
|
|
||||||
|
if args.tts_backend == "vall-e":
|
||||||
|
|
||||||
|
if self.buffer_json:
|
||||||
|
self.json_buffer.append(line)
|
||||||
|
|
||||||
|
if line.find("{") == 0 and not self.buffer_json:
|
||||||
|
self.buffer_json = True
|
||||||
|
self.json_buffer = [line]
|
||||||
|
if line.find("}") == 0 and self.buffer_json:
|
||||||
|
try:
|
||||||
|
data = json.loads("\n".join(self.json_buffer))
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
|
||||||
|
if data and 'model.loss' in data:
|
||||||
|
self.training_started = True
|
||||||
|
data = self.parse_valle_metrics( data )
|
||||||
|
print("Training JSON:", data)
|
||||||
|
else:
|
||||||
|
data = None
|
||||||
|
|
||||||
|
self.buffer_json = None
|
||||||
|
self.json_buffer = []
|
||||||
|
|
||||||
|
if line.find(MESSAGE_FINSIHED) >= 0:
|
||||||
self.killed = True
|
self.killed = True
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
elif not self.training_started:
|
elif not self.training_started:
|
||||||
if line.find('Start training from epoch') >= 0:
|
if line.find(MESSAGE_START) >= 0:
|
||||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||||
|
|
||||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||||
|
@ -937,24 +973,23 @@ class TrainingState():
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
self.it = int(match[0].replace(",", ""))
|
self.it = int(match[0].replace(",", ""))
|
||||||
|
|
||||||
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
self.checkpoints = int((self.its - self.it) / self.save_rate)
|
||||||
|
|
||||||
self.load_statistics()
|
self.load_statistics()
|
||||||
|
|
||||||
should_return = True
|
should_return = True
|
||||||
else:
|
else:
|
||||||
data = None
|
if line.find(MESSAGE_SAVING) >= 0:
|
||||||
if line.find('INFO: Saving models and training states.') >= 0:
|
|
||||||
self.checkpoint += 1
|
self.checkpoint += 1
|
||||||
message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
|
message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
|
||||||
percent = self.checkpoint / self.checkpoints
|
percent = self.checkpoint / self.checkpoints
|
||||||
|
|
||||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||||
elif line.find('INFO: Training Metrics:') >= 0:
|
elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
|
||||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
data = json.loads(line.split(MESSAGE_METRICS_TRAINING)[-1])
|
||||||
data['mode'] = "training"
|
data['mode'] = "training"
|
||||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
|
||||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
data = json.loads(line.split(MESSAGE_METRICS_VALIDATION)[-1])
|
||||||
data['mode'] = "validation"
|
data['mode'] = "validation"
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
|
@ -1278,7 +1313,7 @@ def phonemize_txt_file( path ):
|
||||||
audio = split[0]
|
audio = split[0]
|
||||||
text = split[2]
|
text = split[2]
|
||||||
|
|
||||||
phonemes = phonemizer( text, preserve_punctuation=True, strip=True )
|
phonemes = phonemizer( text )
|
||||||
reparsed.append(f'{audio}|{phonemes}')
|
reparsed.append(f'{audio}|{phonemes}')
|
||||||
f.write(f'\n{audio}|{phonemes}')
|
f.write(f'\n{audio}|{phonemes}')
|
||||||
|
|
||||||
|
@ -1321,6 +1356,21 @@ def create_dataset_json( path ):
|
||||||
with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
|
with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps(data, indent="\t"))
|
f.write(json.dumps(data, indent="\t"))
|
||||||
|
|
||||||
|
def phonemizer( text, language="en-us" ):
|
||||||
|
from phonemizer import phonemize
|
||||||
|
if language == "english":
|
||||||
|
language = "en-us"
|
||||||
|
return phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
|
||||||
|
|
||||||
|
def should_phonemize():
|
||||||
|
try:
|
||||||
|
from phonemizer import phonemize
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
return args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json"
|
||||||
|
|
||||||
def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
|
def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress() ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
|
@ -1332,7 +1382,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
errored = 0
|
errored = 0
|
||||||
messages = []
|
messages = []
|
||||||
normalize = True
|
normalize = True
|
||||||
phonemize = args.tokenizer_json is not None and args.tokenizer_json[-8:] == "ipa.json"
|
phonemize = should_phonemize()
|
||||||
lines = { 'training': [], 'validation': [] }
|
lines = { 'training': [], 'validation': [] }
|
||||||
segments = {}
|
segments = {}
|
||||||
|
|
||||||
|
@ -1374,7 +1424,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
if use_segment and not use_segments:
|
if use_segment and not use_segments:
|
||||||
exists = True
|
exists = True
|
||||||
for segment in result['segments']:
|
for segment in result['segments']:
|
||||||
if os.path.exists(filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")):
|
duration = segment['end'] - segment['start']
|
||||||
|
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
||||||
|
if os.path.exists(path):
|
||||||
continue
|
continue
|
||||||
exists = False
|
exists = False
|
||||||
break
|
break
|
||||||
|
@ -1396,6 +1451,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
for segment in result['segments']:
|
for segment in result['segments']:
|
||||||
|
duration = segment['end'] - segment['start']
|
||||||
|
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||||
|
continue
|
||||||
|
|
||||||
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
|
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
|
||||||
'text': segment['text'],
|
'text': segment['text'],
|
||||||
'language': language,
|
'language': language,
|
||||||
|
@ -1412,7 +1471,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
normalizer = result['normalizer']
|
normalizer = result['normalizer']
|
||||||
phonemes = result['phonemes']
|
phonemes = result['phonemes']
|
||||||
if phonemize and phonemes is None:
|
if phonemize and phonemes is None:
|
||||||
phonemes = phonemizer( text, language=language if language != "english" else "en-us", strip=True, preserve_punctuation=True, with_stress=True, backend=args.phonemizer_backend )
|
phonemes = phonemizer( text, language=language )
|
||||||
if phonemize:
|
if phonemize:
|
||||||
text = phonemes
|
text = phonemes
|
||||||
|
|
||||||
|
@ -1456,7 +1515,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
print("Quantized:", file)
|
print("Quantized:", file)
|
||||||
|
|
||||||
tokens = tokenize_text(text, stringed=False, skip_specials=True)
|
tokens = tokenize_text(text, stringed=False, skip_specials=True)
|
||||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join( tokens ).replace(" \u02C8", "\u02C8"))
|
tokenized = " ".join( tokens )
|
||||||
|
tokenized = tokenized.replace(" \u02C8", "\u02C8")
|
||||||
|
tokenized = tokenized.replace(" \u02CC", "\u02CC")
|
||||||
|
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
validation_joined = "\n".join(lines['validation'])
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
@ -1471,8 +1533,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def calc_iterations( epochs, lines, batch_size ):
|
def calc_iterations( epochs, lines, batch_size ):
|
||||||
iterations = int(epochs * lines / float(batch_size))
|
return int(math.ceil(epochs * math.ceil(lines / batch_size)))
|
||||||
return iterations
|
|
||||||
|
|
||||||
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
|
def schedule_learning_rate( iterations, schedule=LEARNING_RATE_SCHEDULE ):
|
||||||
return [int(iterations * d) for d in schedule]
|
return [int(iterations * d) for d in schedule]
|
||||||
|
@ -1580,7 +1641,9 @@ def optimize_training_settings( **kwargs ):
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
if not os.path.exists(get_halfp_model_path()):
|
||||||
convert_to_halfp()
|
convert_to_halfp()
|
||||||
|
|
||||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({int(iterations / settings['epochs'])} steps per epoch)")
|
settings['steps'] = int(iterations / settings['epochs'])
|
||||||
|
|
||||||
|
messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({settings['steps']}) steps per epoch)")
|
||||||
|
|
||||||
return settings, messages
|
return settings, messages
|
||||||
|
|
||||||
|
@ -1589,6 +1652,7 @@ def save_training_settings( **kwargs ):
|
||||||
settings = {}
|
settings = {}
|
||||||
settings.update(kwargs)
|
settings.update(kwargs)
|
||||||
|
|
||||||
|
|
||||||
outjson = f'./training/{settings["voice"]}/train.json'
|
outjson = f'./training/{settings["voice"]}/train.json'
|
||||||
with open(outjson, 'w', encoding="utf-8") as f:
|
with open(outjson, 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
@ -1599,6 +1663,8 @@ def save_training_settings( **kwargs ):
|
||||||
with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
|
with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
|
||||||
lines = len(f.readlines())
|
lines = len(f.readlines())
|
||||||
|
|
||||||
|
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||||
|
|
||||||
if not settings['source_model'] or settings['source_model'] == "auto":
|
if not settings['source_model'] or settings['source_model'] == "auto":
|
||||||
settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"
|
settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"
|
||||||
|
|
||||||
|
@ -1606,7 +1672,6 @@ def save_training_settings( **kwargs ):
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
if not os.path.exists(get_halfp_model_path()):
|
||||||
convert_to_halfp()
|
convert_to_halfp()
|
||||||
|
|
||||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
|
||||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
||||||
|
|
||||||
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
||||||
|
@ -1622,15 +1687,14 @@ def save_training_settings( **kwargs ):
|
||||||
if settings['validation_rate'] < 1:
|
if settings['validation_rate'] < 1:
|
||||||
settings['validation_rate'] = 1
|
settings['validation_rate'] = 1
|
||||||
"""
|
"""
|
||||||
|
"""
|
||||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
|
||||||
|
|
||||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
|
||||||
if settings['iterations'] % settings['save_rate'] != 0:
|
if settings['iterations'] % settings['save_rate'] != 0:
|
||||||
adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate']
|
adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate']
|
||||||
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}")
|
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}")
|
||||||
settings['iterations'] = adjustment
|
settings['iterations'] = adjustment
|
||||||
|
"""
|
||||||
|
|
||||||
|
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||||
if not os.path.exists(settings['validation_path']):
|
if not os.path.exists(settings['validation_path']):
|
||||||
settings['validation_enabled'] = False
|
settings['validation_enabled'] = False
|
||||||
messages.append("Validation not found, disabling validation...")
|
messages.append("Validation not found, disabling validation...")
|
||||||
|
@ -1833,7 +1897,7 @@ def tokenize_text( text, stringed=True, skip_specials=False ):
|
||||||
tts.tokenizer
|
tts.tokenizer
|
||||||
|
|
||||||
encoded = tokenizer.encode(text)
|
encoded = tokenizer.encode(text)
|
||||||
decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=specials).split(" ")
|
decoded = tokenizer.tokenizer.decode(encoded, skip_special_tokens=skip_specials).split(" ")
|
||||||
|
|
||||||
if stringed:
|
if stringed:
|
||||||
return "\n".join([ str(encoded), str(decoded) ])
|
return "\n".join([ str(encoded), str(decoded) ])
|
||||||
|
|
16
src/webui.py
16
src/webui.py
|
@ -467,12 +467,12 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
|
||||||
with gr.Row():
|
with gr.Row(visible=args.tts_backend=="tortoise"):
|
||||||
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
|
||||||
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
|
TRAINING_SETTINGS["mel_lr_weight"] = gr.Slider(label="Mel LR Ratio", value=1.00, minimum=0, maximum=1)
|
||||||
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
|
TRAINING_SETTINGS["text_lr_weight"] = gr.Slider(label="Text LR Ratio", value=0.01, minimum=0, maximum=1)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row(visible=args.tts_backend=="tortoise"):
|
||||||
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
lr_schemes = list(LEARNING_RATE_SCHEMES.keys())
|
||||||
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
|
TRAINING_SETTINGS["learning_rate_scheme"] = gr.Radio(lr_schemes, label="Learning Rate Scheme", value=lr_schemes[0], type="value")
|
||||||
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
|
TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(LEARNING_RATE_SCHEDULE), visible=True)
|
||||||
|
@ -488,22 +488,22 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
||||||
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
|
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0, visible=args.tts_backend=="tortoise")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
|
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
|
||||||
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
|
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
|
TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp, visible=args.tts_backend=="tortoise")
|
||||||
TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
|
TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb, visible=args.tts_backend=="tortoise")
|
||||||
TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
|
TRAINING_SETTINGS["validation_enabled"] = gr.Checkbox(label="Validation Enabled", value=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0)
|
TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0, visible=args.tts_backend=="tortoise")
|
||||||
TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0)
|
TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0)
|
||||||
|
|
||||||
TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
|
TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0], visible=args.tts_backend=="tortoise" )
|
||||||
TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state")
|
TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/finetune/training_state/${last_state}.state", visible=args.tts_backend=="tortoise")
|
||||||
|
|
||||||
TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user