forked from mrq/ai-voice-cloning
tab to generate the training YAML
This commit is contained in:
parent
3a078df95e
commit
f8249aa826
21
README.md
21
README.md
|
@ -221,6 +221,27 @@ If you want to reuse its generation settings, simply click `Copy Settings`.
|
|||
|
||||
To import a voice, click `Import Voice`. Remember to click `Refresh Voice List` in the `Generate` panel afterwards, if it's a new voice.
|
||||
|
||||
### Training
|
||||
|
||||
This tab will contain a collection of sub-tabs pertaining to training.
|
||||
|
||||
#### Configuration
|
||||
|
||||
This will generate the YAML necessary to feed into training. For now, you can set:
|
||||
* `Batch Size`: size of batches for training, more batches = faster training, at the cost of higher VRAM. setting this to 1 will lead to problems
|
||||
* `Learning Rate`: how large changes to training will be made, lower values = better over the long term, while higher values will fry a model fast. For fine-tuning, the default *should* be fine, but in the future, a learning rate scheduler would be better (have a higher learning rate initially, then step it down over enough steps/epochs)
|
||||
* `Print Frequency`: how often to print (I assume)
|
||||
* `Save Frequency`: how often to save checkpoints
|
||||
* `Training Name`: name to save the configuration as, as well as the training script to create the folder under
|
||||
* `Dataset Name`: **!**TODO**!**: fill
|
||||
* `Dataset Path`: path to the input training text file. For LJSpeech-esque datasets, this is to a textfile formatted like:
|
||||
```
|
||||
wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
|
||||
wavs/LJ001-0002.wav|in being comparatively modern.|in being comparatively modern.
|
||||
```
|
||||
* `Validation Name`: **!**TODO**!**: fill
|
||||
* `Validation Path`: path for the validation set, similar to the dataset. I'm not necessarily sure what to really use for this, so explicitly for testing, I just copied the training dataset text
|
||||
|
||||
### Settings
|
||||
|
||||
This tab (should) hold a bunch of other settings, from tunables that shouldn't be tampered with, to settings pertaining to the web UI itself.
|
||||
|
|
35
src/train.py
Executable file
35
src/train.py
Executable file
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
import argparse
|
||||
|
||||
from ..dlas.codes import *
|
||||
from ..dlas.codes.utils import util, options as option
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
if args.launcher != 'none':
|
||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||
if 'gpu_ids' in opt.keys():
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
trainer = Trainer()
|
||||
|
||||
#### distributed training settings
|
||||
if args.launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
if len(opt['gpu_ids']) == 1:
|
||||
torch.cuda.set_device(opt['gpu_ids'][0])
|
||||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist('nccl')
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(args.opt, opt, args.launcher)
|
||||
trainer.do_training()
|
252
src/utils.py
252
src/utils.py
|
@ -138,7 +138,7 @@ def generate(
|
|||
try:
|
||||
tts
|
||||
except NameError:
|
||||
raise gr.Error("TTS is still initializing...")
|
||||
raise Exception("TTS is still initializing...")
|
||||
|
||||
if voice != "microphone":
|
||||
voices = [voice]
|
||||
|
@ -147,7 +147,7 @@ def generate(
|
|||
|
||||
if voice == "microphone":
|
||||
if mic_audio is None:
|
||||
raise gr.Error("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||
mic = load_audio(mic_audio, tts.input_sample_rate)
|
||||
voice_samples, conditioning_latents = [mic], None
|
||||
elif voice == "random":
|
||||
|
@ -431,4 +431,250 @@ def setup_tortoise(restart=False):
|
|||
print("Initializating TorToiSe...")
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||
print("TorToiSe initialized, ready for generation.")
|
||||
return tts
|
||||
return tts
|
||||
|
||||
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
|
||||
settings = {
|
||||
"batch_size": batch_size if batch_size else 128,
|
||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||
"print_rate": print_rate if print_rate else 50,
|
||||
"save_rate": save_rate if save_rate else 50,
|
||||
"name": name if name else "finetune",
|
||||
"dataset_name": dataset_name if dataset_name else "finetune",
|
||||
"dataset_path": dataset_path if dataset_path else "./experiments/finetune/train.txt",
|
||||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./experiments/finetune/val.txt",
|
||||
}
|
||||
|
||||
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
|
||||
yaml = f.read()
|
||||
|
||||
for k in settings:
|
||||
print(f"${{{k}}} => {settings[k]}")
|
||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||
|
||||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||
f.write(yaml)
|
||||
|
||||
def reset_generation_settings():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps({}, indent='\t') )
|
||||
return import_generate_settings()
|
||||
|
||||
def import_voice(file, saveAs = None):
|
||||
global args
|
||||
|
||||
j, latents = read_generate_settings(file, read_latents=True)
|
||||
|
||||
if j is not None and saveAs is None:
|
||||
saveAs = j['voice']
|
||||
if saveAs is None or saveAs == "":
|
||||
raise Exception("Specify a voice name")
|
||||
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
if latents:
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
print(f"Imported latents to {latents}")
|
||||
else:
|
||||
filename = file.name
|
||||
if filename[-4:] != ".wav":
|
||||
raise Exception("Please convert to a WAV first")
|
||||
|
||||
path = f"{outdir}/{os.path.basename(filename)}"
|
||||
waveform, sampling_rate = torchaudio.load(filename)
|
||||
|
||||
if args.voice_fixer:
|
||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||
if sampling_rate != 44100:
|
||||
print(f"Resampling imported voice sample: {path}")
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
sampling_rate,
|
||||
44100,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
sampling_rate = 44100
|
||||
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||
voicefixer.restore(
|
||||
input = path,
|
||||
output = path,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
else:
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
|
||||
print(f"Imported voice to {path}")
|
||||
|
||||
|
||||
def import_generate_settings(file="./config/generate.json"):
|
||||
settings, _ = read_generate_settings(file, read_latents=False)
|
||||
|
||||
if settings is None:
|
||||
return None
|
||||
|
||||
return (
|
||||
None if 'text' not in settings else settings['text'],
|
||||
None if 'delimiter' not in settings else settings['delimiter'],
|
||||
None if 'emotion' not in settings else settings['emotion'],
|
||||
None if 'prompt' not in settings else settings['prompt'],
|
||||
None if 'voice' not in settings else settings['voice'],
|
||||
None,
|
||||
None,
|
||||
None if 'seed' not in settings else settings['seed'],
|
||||
None if 'candidates' not in settings else settings['candidates'],
|
||||
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
|
||||
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
|
||||
0.8 if 'temperature' not in settings else settings['temperature'],
|
||||
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
|
||||
8 if 'breathing_room' not in settings else settings['breathing_room'],
|
||||
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
|
||||
0.8 if 'top_p' not in settings else settings['top_p'],
|
||||
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
|
||||
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
|
||||
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
|
||||
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
|
||||
None if 'experimentals' not in settings else settings['experimentals'],
|
||||
)
|
||||
|
||||
def curl(url):
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
|
||||
conn = urllib.request.urlopen(req)
|
||||
data = conn.read()
|
||||
data = data.decode()
|
||||
data = json.loads(data)
|
||||
conn.close()
|
||||
return data
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def check_for_updates():
|
||||
if not os.path.isfile('./.git/FETCH_HEAD'):
|
||||
print("Cannot check for updates: not from a git repo")
|
||||
return False
|
||||
|
||||
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
||||
head = f.read()
|
||||
|
||||
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
||||
if match is None or len(match) == 0:
|
||||
print("Cannot check for updates: cannot parse FETCH_HEAD")
|
||||
return False
|
||||
|
||||
match = match[0]
|
||||
|
||||
local = match[0]
|
||||
host = match[1]
|
||||
owner = match[2]
|
||||
repo = match[3]
|
||||
|
||||
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
|
||||
|
||||
if res is None or len(res) == 0:
|
||||
print("Cannot check for updates: cannot fetch from remote")
|
||||
return False
|
||||
|
||||
remote = res[0]["commit"]["id"]
|
||||
|
||||
if remote != local:
|
||||
print(f"New version found: {local[:8]} => {remote[:8]}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def reload_tts():
|
||||
global tts
|
||||
del tts
|
||||
tts = setup_tortoise(restart=True)
|
||||
|
||||
def cancel_generate():
|
||||
tortoise.api.STOP_SIGNAL = True
|
||||
|
||||
def get_voice_list(dir=get_voice_dir()):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
|
||||
|
||||
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
||||
global args
|
||||
|
||||
args.listen = listen
|
||||
args.share = share
|
||||
args.check_for_updates = check_for_updates
|
||||
args.models_from_local_only = models_from_local_only
|
||||
args.low_vram = low_vram
|
||||
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
||||
args.device_override = device_override
|
||||
args.sample_batch_size = sample_batch_size
|
||||
args.embed_output_metadata = embed_output_metadata
|
||||
args.latents_lean_and_mean = latents_lean_and_mean
|
||||
args.voice_fixer = voice_fixer
|
||||
args.voice_fixer_use_cuda = voice_fixer_use_cuda
|
||||
args.concurrency_count = concurrency_count
|
||||
args.output_sample_rate = output_sample_rate
|
||||
args.output_volume = output_volume
|
||||
|
||||
settings = {
|
||||
'listen': None if args.listen else args.listen,
|
||||
'share': args.share,
|
||||
'low-vram':args.low_vram,
|
||||
'check-for-updates':args.check_for_updates,
|
||||
'models-from-local-only':args.models_from_local_only,
|
||||
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
||||
'device-override': args.device_override,
|
||||
'sample-batch-size': args.sample_batch_size,
|
||||
'embed-output-metadata': args.embed_output_metadata,
|
||||
'latents-lean-and-mean': args.latents_lean_and_mean,
|
||||
'voice-fixer': args.voice_fixer,
|
||||
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
|
||||
'concurrency-count': args.concurrency_count,
|
||||
'output-sample-rate': args.output_sample_rate,
|
||||
'output-volume': args.output_volume,
|
||||
}
|
||||
|
||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(settings, indent='\t') )
|
||||
|
||||
def read_generate_settings(file, read_latents=True, read_json=True):
|
||||
j = None
|
||||
latents = None
|
||||
|
||||
if file is not None:
|
||||
if hasattr(file, 'name'):
|
||||
file = file.name
|
||||
|
||||
if file[-4:] == ".wav":
|
||||
metadata = music_tag.load_file(file)
|
||||
if 'lyrics' in metadata:
|
||||
j = json.loads(str(metadata['lyrics']))
|
||||
elif file[-5:] == ".json":
|
||||
with open(file, 'r') as f:
|
||||
j = json.load(f)
|
||||
|
||||
if j is None:
|
||||
print("No metadata found in audio file to read")
|
||||
else:
|
||||
if 'latents' in j:
|
||||
if read_latents:
|
||||
latents = base64.b64decode(j['latents'])
|
||||
del j['latents']
|
||||
|
||||
|
||||
if "time" in j:
|
||||
j["time"] = "{:.3f}".format(j["time"])
|
||||
|
||||
return (
|
||||
j,
|
||||
latents,
|
||||
)
|
389
src/webui.py
389
src/webui.py
|
@ -21,6 +21,71 @@ from utils import *
|
|||
|
||||
args = setup_args()
|
||||
|
||||
def run_generation(
|
||||
text,
|
||||
delimiter,
|
||||
emotion,
|
||||
prompt,
|
||||
voice,
|
||||
mic_audio,
|
||||
voice_latents_chunks,
|
||||
seed,
|
||||
candidates,
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
temperature,
|
||||
diffusion_sampler,
|
||||
breathing_room,
|
||||
cvvp_weight,
|
||||
top_p,
|
||||
diffusion_temperature,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
cond_free_k,
|
||||
experimental_checkboxes,
|
||||
progress=gr.Progress(track_tqdm=True)
|
||||
):
|
||||
try:
|
||||
sample, outputs, stats = generate(
|
||||
text,
|
||||
delimiter,
|
||||
emotion,
|
||||
prompt,
|
||||
voice,
|
||||
mic_audio,
|
||||
voice_latents_chunks,
|
||||
seed,
|
||||
candidates,
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
temperature,
|
||||
diffusion_sampler,
|
||||
breathing_room,
|
||||
cvvp_weight,
|
||||
top_p,
|
||||
diffusion_temperature,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
cond_free_k,
|
||||
experimental_checkboxes,
|
||||
progress
|
||||
)
|
||||
except Exception as e:
|
||||
message = str(e)
|
||||
if message == "Kill signal detected":
|
||||
reload_tts()
|
||||
|
||||
raise gr.Error(message)
|
||||
|
||||
|
||||
return (
|
||||
outputs[0],
|
||||
gr.update(value=sample, visible=sample is not None),
|
||||
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
||||
gr.update(visible=len(outputs) > 1),
|
||||
gr.update(value=stats, visible=True),
|
||||
)
|
||||
|
||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
global tts
|
||||
global args
|
||||
|
@ -58,230 +123,6 @@ def update_presets(value):
|
|||
else:
|
||||
return (gr.update(), gr.update())
|
||||
|
||||
def read_generate_settings(file, read_latents=True, read_json=True):
|
||||
j = None
|
||||
latents = None
|
||||
|
||||
if file is not None:
|
||||
if hasattr(file, 'name'):
|
||||
file = file.name
|
||||
|
||||
if file[-4:] == ".wav":
|
||||
metadata = music_tag.load_file(file)
|
||||
if 'lyrics' in metadata:
|
||||
j = json.loads(str(metadata['lyrics']))
|
||||
elif file[-5:] == ".json":
|
||||
with open(file, 'r') as f:
|
||||
j = json.load(f)
|
||||
|
||||
if j is None:
|
||||
gr.Error("No metadata found in audio file to read")
|
||||
else:
|
||||
if 'latents' in j:
|
||||
if read_latents:
|
||||
latents = base64.b64decode(j['latents'])
|
||||
del j['latents']
|
||||
|
||||
|
||||
if "time" in j:
|
||||
j["time"] = "{:.3f}".format(j["time"])
|
||||
|
||||
return (
|
||||
j,
|
||||
latents,
|
||||
)
|
||||
|
||||
def import_voice(file, saveAs = None):
|
||||
global args
|
||||
|
||||
j, latents = read_generate_settings(file, read_latents=True)
|
||||
|
||||
if j is not None and saveAs is None:
|
||||
saveAs = j['voice']
|
||||
if saveAs is None or saveAs == "":
|
||||
raise gr.Error("Specify a voice name")
|
||||
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
if latents:
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
print(f"Imported latents to {latents}")
|
||||
else:
|
||||
filename = file.name
|
||||
if filename[-4:] != ".wav":
|
||||
raise gr.Error("Please convert to a WAV first")
|
||||
|
||||
path = f"{outdir}/{os.path.basename(filename)}"
|
||||
waveform, sampling_rate = torchaudio.load(filename)
|
||||
|
||||
if args.voice_fixer:
|
||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||
if sampling_rate != 44100:
|
||||
print(f"Resampling imported voice sample: {path}")
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
sampling_rate,
|
||||
44100,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
sampling_rate = 44100
|
||||
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||
voicefixer.restore(
|
||||
input = path,
|
||||
output = path,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
else:
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
|
||||
print(f"Imported voice to {path}")
|
||||
|
||||
|
||||
def import_generate_settings(file="./config/generate.json"):
|
||||
settings, _ = read_generate_settings(file, read_latents=False)
|
||||
|
||||
if settings is None:
|
||||
return None
|
||||
|
||||
return (
|
||||
None if 'text' not in settings else settings['text'],
|
||||
None if 'delimiter' not in settings else settings['delimiter'],
|
||||
None if 'emotion' not in settings else settings['emotion'],
|
||||
None if 'prompt' not in settings else settings['prompt'],
|
||||
None if 'voice' not in settings else settings['voice'],
|
||||
None,
|
||||
None,
|
||||
None if 'seed' not in settings else settings['seed'],
|
||||
None if 'candidates' not in settings else settings['candidates'],
|
||||
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
|
||||
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
|
||||
0.8 if 'temperature' not in settings else settings['temperature'],
|
||||
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
|
||||
8 if 'breathing_room' not in settings else settings['breathing_room'],
|
||||
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
|
||||
0.8 if 'top_p' not in settings else settings['top_p'],
|
||||
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
|
||||
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
|
||||
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
|
||||
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
|
||||
None if 'experimentals' not in settings else settings['experimentals'],
|
||||
)
|
||||
|
||||
def curl(url):
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
|
||||
conn = urllib.request.urlopen(req)
|
||||
data = conn.read()
|
||||
data = data.decode()
|
||||
data = json.loads(data)
|
||||
conn.close()
|
||||
return data
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def check_for_updates():
|
||||
if not os.path.isfile('./.git/FETCH_HEAD'):
|
||||
print("Cannot check for updates: not from a git repo")
|
||||
return False
|
||||
|
||||
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
||||
head = f.read()
|
||||
|
||||
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
||||
if match is None or len(match) == 0:
|
||||
print("Cannot check for updates: cannot parse FETCH_HEAD")
|
||||
return False
|
||||
|
||||
match = match[0]
|
||||
|
||||
local = match[0]
|
||||
host = match[1]
|
||||
owner = match[2]
|
||||
repo = match[3]
|
||||
|
||||
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
|
||||
|
||||
if res is None or len(res) == 0:
|
||||
print("Cannot check for updates: cannot fetch from remote")
|
||||
return False
|
||||
|
||||
remote = res[0]["commit"]["id"]
|
||||
|
||||
if remote != local:
|
||||
print(f"New version found: {local[:8]} => {remote[:8]}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def reload_tts():
|
||||
global tts
|
||||
del tts
|
||||
tts = setup_tortoise(restart=True)
|
||||
|
||||
def cancel_generate():
|
||||
tortoise.api.STOP_SIGNAL = True
|
||||
|
||||
def get_voice_list(dir=get_voice_dir()):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||
)
|
||||
|
||||
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
||||
global args
|
||||
|
||||
args.listen = listen
|
||||
args.share = share
|
||||
args.check_for_updates = check_for_updates
|
||||
args.models_from_local_only = models_from_local_only
|
||||
args.low_vram = low_vram
|
||||
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
||||
args.device_override = device_override
|
||||
args.sample_batch_size = sample_batch_size
|
||||
args.embed_output_metadata = embed_output_metadata
|
||||
args.latents_lean_and_mean = latents_lean_and_mean
|
||||
args.voice_fixer = voice_fixer
|
||||
args.voice_fixer_use_cuda = voice_fixer_use_cuda
|
||||
args.concurrency_count = concurrency_count
|
||||
args.output_sample_rate = output_sample_rate
|
||||
args.output_volume = output_volume
|
||||
|
||||
settings = {
|
||||
'listen': None if args.listen else args.listen,
|
||||
'share': args.share,
|
||||
'low-vram':args.low_vram,
|
||||
'check-for-updates':args.check_for_updates,
|
||||
'models-from-local-only':args.models_from_local_only,
|
||||
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
||||
'device-override': args.device_override,
|
||||
'sample-batch-size': args.sample_batch_size,
|
||||
'embed-output-metadata': args.embed_output_metadata,
|
||||
'latents-lean-and-mean': args.latents_lean_and_mean,
|
||||
'voice-fixer': args.voice_fixer,
|
||||
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
|
||||
'concurrency-count': args.concurrency_count,
|
||||
'output-sample-rate': args.output_sample_rate,
|
||||
'output-volume': args.output_volume,
|
||||
}
|
||||
|
||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(settings, indent='\t') )
|
||||
|
||||
def setup_gradio():
|
||||
global args
|
||||
global ui
|
||||
|
@ -528,6 +369,29 @@ def setup_gradio():
|
|||
import_voice_name,
|
||||
]
|
||||
)
|
||||
with gr.Tab("Training"):
|
||||
with gr.Tab("Configuration"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_settings = [
|
||||
gr.Slider(label="Batch Size", value=128),
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
]
|
||||
with gr.Column():
|
||||
training_settings = training_settings + [
|
||||
gr.Textbox(label="Training Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Dataset Path", placeholder="./experiments/finetune/train.txt"),
|
||||
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Validation Path", placeholder="./experiments/finetune/val.txt"),
|
||||
]
|
||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||
save_yaml_button.click(save_training_settings,
|
||||
inputs=training_settings,
|
||||
outputs=None
|
||||
)
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
exec_inputs = []
|
||||
|
@ -586,71 +450,15 @@ def setup_gradio():
|
|||
]
|
||||
|
||||
# YUCK
|
||||
def run_generation(
|
||||
text,
|
||||
delimiter,
|
||||
emotion,
|
||||
prompt,
|
||||
voice,
|
||||
mic_audio,
|
||||
voice_latents_chunks,
|
||||
seed,
|
||||
candidates,
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
temperature,
|
||||
diffusion_sampler,
|
||||
breathing_room,
|
||||
cvvp_weight,
|
||||
top_p,
|
||||
diffusion_temperature,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
cond_free_k,
|
||||
experimental_checkboxes,
|
||||
progress=gr.Progress(track_tqdm=True)
|
||||
):
|
||||
try:
|
||||
sample, outputs, stats = generate(
|
||||
text,
|
||||
delimiter,
|
||||
emotion,
|
||||
prompt,
|
||||
voice,
|
||||
mic_audio,
|
||||
voice_latents_chunks,
|
||||
seed,
|
||||
candidates,
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
temperature,
|
||||
diffusion_sampler,
|
||||
breathing_room,
|
||||
cvvp_weight,
|
||||
top_p,
|
||||
diffusion_temperature,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
cond_free_k,
|
||||
experimental_checkboxes,
|
||||
progress
|
||||
)
|
||||
except Exception as e:
|
||||
message = str(e)
|
||||
if message == "Kill signal detected":
|
||||
reload_tts()
|
||||
|
||||
raise gr.Error(message)
|
||||
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
outputs[0],
|
||||
gr.update(value=sample, visible=sample is not None),
|
||||
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
||||
gr.update(visible=len(outputs) > 1),
|
||||
gr.update(value=stats, visible=True),
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||
)
|
||||
|
||||
def history_copy_settings( voice, file ):
|
||||
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||
|
||||
refresh_voices.click(update_voices,
|
||||
inputs=None,
|
||||
outputs=[
|
||||
|
@ -681,21 +489,12 @@ def setup_gradio():
|
|||
outputs=input_settings
|
||||
)
|
||||
|
||||
def reset_generation_settings():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps({}, indent='\t') )
|
||||
return import_generate_settings()
|
||||
|
||||
reset_generation_settings_button.click(
|
||||
fn=reset_generation_settings,
|
||||
inputs=None,
|
||||
outputs=input_settings
|
||||
)
|
||||
|
||||
def history_copy_settings( voice, file ):
|
||||
settings = import_generate_settings( f"./results/{voice}/{file}" )
|
||||
return settings
|
||||
|
||||
history_copy_settings_button.click(history_copy_settings,
|
||||
inputs=[
|
||||
history_voices,
|
||||
|
|
144
training/.template.yaml
Executable file
144
training/.template.yaml
Executable file
|
@ -0,0 +1,144 @@
|
|||
name: ${name}
|
||||
model: extensibletrainer
|
||||
scale: 1
|
||||
gpu_ids: [0] # <-- unless you have multiple gpus, use this
|
||||
start_step: -1
|
||||
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
|
||||
fp16: false # might want to check this out
|
||||
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
|
||||
use_tb_logger: true
|
||||
|
||||
datasets:
|
||||
train:
|
||||
name: ${dataset_name}
|
||||
n_workers: 8 # idk what this does
|
||||
batch_size: ${batch_size} # This leads to ~16GB of vram usage on my 3090.
|
||||
mode: paired_voice_audio
|
||||
path: ${dataset_path}
|
||||
fetcher_mode: ['lj'] # CHANGEME if your dataset isn't in LJSpeech format
|
||||
phase: train
|
||||
max_wav_length: 255995
|
||||
max_text_length: 200
|
||||
sample_rate: 22050
|
||||
load_conditioning: True
|
||||
num_conditioning_candidates: 2
|
||||
conditioning_length: 44000
|
||||
use_bpe_tokenizer: True
|
||||
load_aligned_codes: False
|
||||
val:
|
||||
name: ${validation_name}
|
||||
n_workers: 1
|
||||
batch_size: 32 # this could be higher probably
|
||||
mode: paired_voice_audio
|
||||
path: ${validation_path}
|
||||
fetcher_mode: ['lj']
|
||||
phase: val # might be broken idk
|
||||
max_wav_length: 255995
|
||||
max_text_length: 200
|
||||
sample_rate: 22050
|
||||
load_conditioning: True
|
||||
num_conditioning_candidates: 2
|
||||
conditioning_length: 44000
|
||||
use_bpe_tokenizer: True
|
||||
load_aligned_codes: False
|
||||
|
||||
steps:
|
||||
gpt_train:
|
||||
training: gpt
|
||||
loss_log_buffer: 500 # no idea what this does
|
||||
|
||||
# Generally follows the recipe from the DALLE paper.
|
||||
optimizer: adamw # this should be adamw_zero if you're using distributed training
|
||||
optimizer_params:
|
||||
lr: !!float ${learning_rate} # CHANGEME: this was originally 1e-4; I reduced it to 1e-5 because it's fine-tuning, but **you should experiment with this value**
|
||||
weight_decay: !!float 1e-2
|
||||
beta1: 0.9
|
||||
beta2: 0.96
|
||||
clip_grad_eps: 4
|
||||
|
||||
injectors: # TODO: replace this entire sequence with the GptVoiceLatentInjector
|
||||
paired_to_mel:
|
||||
type: torch_mel_spectrogram
|
||||
mel_norm_file: ./experiments/clips_mel_norms.pth
|
||||
in: wav
|
||||
out: paired_mel
|
||||
paired_cond_to_mel:
|
||||
type: for_each
|
||||
subtype: torch_mel_spectrogram
|
||||
mel_norm_file: ./experiments/clips_mel_norms.pth
|
||||
in: conditioning
|
||||
out: paired_conditioning_mel
|
||||
to_codes:
|
||||
type: discrete_token
|
||||
in: paired_mel
|
||||
out: paired_mel_codes
|
||||
dvae_config: "./experiments/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
|
||||
paired_fwd_text:
|
||||
type: generator
|
||||
generator: gpt
|
||||
in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
|
||||
out: [loss_text_ce, loss_mel_ce, logits]
|
||||
losses:
|
||||
text_ce:
|
||||
type: direct
|
||||
weight: .01
|
||||
key: loss_text_ce
|
||||
mel_ce:
|
||||
type: direct
|
||||
weight: 1
|
||||
key: loss_mel_ce
|
||||
|
||||
networks:
|
||||
gpt:
|
||||
type: generator
|
||||
which_model_G: unified_voice2 # none of the unified_voice*.py files actually match the tortoise inference code... 4 and 3 have "alignment_head" (wtf is that?), 2 lacks the types=1 parameter.
|
||||
kwargs:
|
||||
layers: 30 # WAS 8
|
||||
model_dim: 1024 # WAS 512
|
||||
heads: 16 # WAS 8
|
||||
max_text_tokens: 402 # WAS 120
|
||||
max_mel_tokens: 604 # WAS 250
|
||||
max_conditioning_inputs: 2 # WAS 1
|
||||
mel_length_compression: 1024
|
||||
number_text_tokens: 256 # supposed to be 255 for newer unified_voice files
|
||||
number_mel_codes: 8194
|
||||
start_mel_token: 8192
|
||||
stop_mel_token: 8193
|
||||
start_text_token: 255
|
||||
train_solo_embeddings: False # missing in uv3/4
|
||||
use_mel_codes_as_input: True # ditto
|
||||
checkpointing: True
|
||||
#types: 1 # this is MISSING, but in my analysis 1 is equivalent to not having it.
|
||||
#only_alignment_head: False # uv3/4
|
||||
|
||||
path:
|
||||
pretrain_model_gpt: './experiments/autoregressive.pth' # CHANGEME: copy this from tortoise cache
|
||||
strict_load: true
|
||||
#resume_state: ./experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||
|
||||
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
||||
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
||||
niter: 50000
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 500
|
||||
|
||||
default_lr_scheme: MultiStepLR
|
||||
gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
output_state: gen
|
||||
injectors:
|
||||
gen_inj_eval:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: hq
|
||||
out: [gen, codebook_commitment_loss]
|
||||
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
|
||||
visuals: [gen, mel]
|
||||
visual_debug_rate: 500
|
||||
is_mel_spectrogram: true
|
Loading…
Reference in New Issue
Block a user