forked from mrq/ai-voice-cloning
tab to generate the training YAML
This commit is contained in:
parent
3a078df95e
commit
f8249aa826
21
README.md
21
README.md
|
@ -221,6 +221,27 @@ If you want to reuse its generation settings, simply click `Copy Settings`.
|
||||||
|
|
||||||
To import a voice, click `Import Voice`. Remember to click `Refresh Voice List` in the `Generate` panel afterwards, if it's a new voice.
|
To import a voice, click `Import Voice`. Remember to click `Refresh Voice List` in the `Generate` panel afterwards, if it's a new voice.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
This tab will contain a collection of sub-tabs pertaining to training.
|
||||||
|
|
||||||
|
#### Configuration
|
||||||
|
|
||||||
|
This will generate the YAML necessary to feed into training. For now, you can set:
|
||||||
|
* `Batch Size`: size of batches for training, more batches = faster training, at the cost of higher VRAM. setting this to 1 will lead to problems
|
||||||
|
* `Learning Rate`: how large changes to training will be made, lower values = better over the long term, while higher values will fry a model fast. For fine-tuning, the default *should* be fine, but in the future, a learning rate scheduler would be better (have a higher learning rate initially, then step it down over enough steps/epochs)
|
||||||
|
* `Print Frequency`: how often to print (I assume)
|
||||||
|
* `Save Frequency`: how often to save checkpoints
|
||||||
|
* `Training Name`: name to save the configuration as, as well as the training script to create the folder under
|
||||||
|
* `Dataset Name`: **!**TODO**!**: fill
|
||||||
|
* `Dataset Path`: path to the input training text file. For LJSpeech-esque datasets, this is to a textfile formatted like:
|
||||||
|
```
|
||||||
|
wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
|
||||||
|
wavs/LJ001-0002.wav|in being comparatively modern.|in being comparatively modern.
|
||||||
|
```
|
||||||
|
* `Validation Name`: **!**TODO**!**: fill
|
||||||
|
* `Validation Path`: path for the validation set, similar to the dataset. I'm not necessarily sure what to really use for this, so explicitly for testing, I just copied the training dataset text
|
||||||
|
|
||||||
### Settings
|
### Settings
|
||||||
|
|
||||||
This tab (should) hold a bunch of other settings, from tunables that shouldn't be tampered with, to settings pertaining to the web UI itself.
|
This tab (should) hold a bunch of other settings, from tunables that shouldn't be tampered with, to settings pertaining to the web UI itself.
|
||||||
|
|
35
src/train.py
Executable file
35
src/train.py
Executable file
|
@ -0,0 +1,35 @@
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from ..dlas.codes import *
|
||||||
|
from ..dlas.codes.utils import util, options as option
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||||
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
|
args = parser.parse_args()
|
||||||
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
if args.launcher != 'none':
|
||||||
|
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||||
|
if 'gpu_ids' in opt.keys():
|
||||||
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||||
|
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||||
|
trainer = Trainer()
|
||||||
|
|
||||||
|
#### distributed training settings
|
||||||
|
if args.launcher == 'none': # disabled distributed training
|
||||||
|
opt['dist'] = False
|
||||||
|
trainer.rank = -1
|
||||||
|
if len(opt['gpu_ids']) == 1:
|
||||||
|
torch.cuda.set_device(opt['gpu_ids'][0])
|
||||||
|
print('Disabled distributed training.')
|
||||||
|
else:
|
||||||
|
opt['dist'] = True
|
||||||
|
init_dist('nccl')
|
||||||
|
trainer.world_size = torch.distributed.get_world_size()
|
||||||
|
trainer.rank = torch.distributed.get_rank()
|
||||||
|
torch.cuda.set_device(torch.distributed.get_rank())
|
||||||
|
|
||||||
|
trainer.init(args.opt, opt, args.launcher)
|
||||||
|
trainer.do_training()
|
252
src/utils.py
252
src/utils.py
|
@ -138,7 +138,7 @@ def generate(
|
||||||
try:
|
try:
|
||||||
tts
|
tts
|
||||||
except NameError:
|
except NameError:
|
||||||
raise gr.Error("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
if voice != "microphone":
|
if voice != "microphone":
|
||||||
voices = [voice]
|
voices = [voice]
|
||||||
|
@ -147,7 +147,7 @@ def generate(
|
||||||
|
|
||||||
if voice == "microphone":
|
if voice == "microphone":
|
||||||
if mic_audio is None:
|
if mic_audio is None:
|
||||||
raise gr.Error("Please provide audio from mic when choosing `microphone` as a voice input")
|
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||||
mic = load_audio(mic_audio, tts.input_sample_rate)
|
mic = load_audio(mic_audio, tts.input_sample_rate)
|
||||||
voice_samples, conditioning_latents = [mic], None
|
voice_samples, conditioning_latents = [mic], None
|
||||||
elif voice == "random":
|
elif voice == "random":
|
||||||
|
@ -431,4 +431,250 @@ def setup_tortoise(restart=False):
|
||||||
print("Initializating TorToiSe...")
|
print("Initializating TorToiSe...")
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
print("TorToiSe initialized, ready for generation.")
|
print("TorToiSe initialized, ready for generation.")
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
|
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
|
||||||
|
settings = {
|
||||||
|
"batch_size": batch_size if batch_size else 128,
|
||||||
|
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||||
|
"print_rate": print_rate if print_rate else 50,
|
||||||
|
"save_rate": save_rate if save_rate else 50,
|
||||||
|
"name": name if name else "finetune",
|
||||||
|
"dataset_name": dataset_name if dataset_name else "finetune",
|
||||||
|
"dataset_path": dataset_path if dataset_path else "./experiments/finetune/train.txt",
|
||||||
|
"validation_name": validation_name if validation_name else "finetune",
|
||||||
|
"validation_path": validation_path if validation_path else "./experiments/finetune/val.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
|
||||||
|
yaml = f.read()
|
||||||
|
|
||||||
|
for k in settings:
|
||||||
|
print(f"${{{k}}} => {settings[k]}")
|
||||||
|
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||||
|
|
||||||
|
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(yaml)
|
||||||
|
|
||||||
|
def reset_generation_settings():
|
||||||
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps({}, indent='\t') )
|
||||||
|
return import_generate_settings()
|
||||||
|
|
||||||
|
def import_voice(file, saveAs = None):
|
||||||
|
global args
|
||||||
|
|
||||||
|
j, latents = read_generate_settings(file, read_latents=True)
|
||||||
|
|
||||||
|
if j is not None and saveAs is None:
|
||||||
|
saveAs = j['voice']
|
||||||
|
if saveAs is None or saveAs == "":
|
||||||
|
raise Exception("Specify a voice name")
|
||||||
|
|
||||||
|
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
if latents:
|
||||||
|
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||||
|
f.write(latents)
|
||||||
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
print(f"Imported latents to {latents}")
|
||||||
|
else:
|
||||||
|
filename = file.name
|
||||||
|
if filename[-4:] != ".wav":
|
||||||
|
raise Exception("Please convert to a WAV first")
|
||||||
|
|
||||||
|
path = f"{outdir}/{os.path.basename(filename)}"
|
||||||
|
waveform, sampling_rate = torchaudio.load(filename)
|
||||||
|
|
||||||
|
if args.voice_fixer:
|
||||||
|
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||||
|
if sampling_rate != 44100:
|
||||||
|
print(f"Resampling imported voice sample: {path}")
|
||||||
|
resampler = torchaudio.transforms.Resample(
|
||||||
|
sampling_rate,
|
||||||
|
44100,
|
||||||
|
lowpass_filter_width=16,
|
||||||
|
rolloff=0.85,
|
||||||
|
resampling_method="kaiser_window",
|
||||||
|
beta=8.555504641634386,
|
||||||
|
)
|
||||||
|
waveform = resampler(waveform)
|
||||||
|
sampling_rate = 44100
|
||||||
|
|
||||||
|
torchaudio.save(path, waveform, sampling_rate)
|
||||||
|
|
||||||
|
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||||
|
voicefixer.restore(
|
||||||
|
input = path,
|
||||||
|
output = path,
|
||||||
|
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||||
|
#mode=mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torchaudio.save(path, waveform, sampling_rate)
|
||||||
|
|
||||||
|
|
||||||
|
print(f"Imported voice to {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def import_generate_settings(file="./config/generate.json"):
|
||||||
|
settings, _ = read_generate_settings(file, read_latents=False)
|
||||||
|
|
||||||
|
if settings is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (
|
||||||
|
None if 'text' not in settings else settings['text'],
|
||||||
|
None if 'delimiter' not in settings else settings['delimiter'],
|
||||||
|
None if 'emotion' not in settings else settings['emotion'],
|
||||||
|
None if 'prompt' not in settings else settings['prompt'],
|
||||||
|
None if 'voice' not in settings else settings['voice'],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None if 'seed' not in settings else settings['seed'],
|
||||||
|
None if 'candidates' not in settings else settings['candidates'],
|
||||||
|
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
|
||||||
|
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
|
||||||
|
0.8 if 'temperature' not in settings else settings['temperature'],
|
||||||
|
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
|
||||||
|
8 if 'breathing_room' not in settings else settings['breathing_room'],
|
||||||
|
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
|
||||||
|
0.8 if 'top_p' not in settings else settings['top_p'],
|
||||||
|
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
|
||||||
|
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
|
||||||
|
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
|
||||||
|
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
|
||||||
|
None if 'experimentals' not in settings else settings['experimentals'],
|
||||||
|
)
|
||||||
|
|
||||||
|
def curl(url):
|
||||||
|
try:
|
||||||
|
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
|
||||||
|
conn = urllib.request.urlopen(req)
|
||||||
|
data = conn.read()
|
||||||
|
data = data.decode()
|
||||||
|
data = json.loads(data)
|
||||||
|
conn.close()
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_for_updates():
|
||||||
|
if not os.path.isfile('./.git/FETCH_HEAD'):
|
||||||
|
print("Cannot check for updates: not from a git repo")
|
||||||
|
return False
|
||||||
|
|
||||||
|
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
||||||
|
head = f.read()
|
||||||
|
|
||||||
|
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
||||||
|
if match is None or len(match) == 0:
|
||||||
|
print("Cannot check for updates: cannot parse FETCH_HEAD")
|
||||||
|
return False
|
||||||
|
|
||||||
|
match = match[0]
|
||||||
|
|
||||||
|
local = match[0]
|
||||||
|
host = match[1]
|
||||||
|
owner = match[2]
|
||||||
|
repo = match[3]
|
||||||
|
|
||||||
|
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
|
||||||
|
|
||||||
|
if res is None or len(res) == 0:
|
||||||
|
print("Cannot check for updates: cannot fetch from remote")
|
||||||
|
return False
|
||||||
|
|
||||||
|
remote = res[0]["commit"]["id"]
|
||||||
|
|
||||||
|
if remote != local:
|
||||||
|
print(f"New version found: {local[:8]} => {remote[:8]}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def reload_tts():
|
||||||
|
global tts
|
||||||
|
del tts
|
||||||
|
tts = setup_tortoise(restart=True)
|
||||||
|
|
||||||
|
def cancel_generate():
|
||||||
|
tortoise.api.STOP_SIGNAL = True
|
||||||
|
|
||||||
|
def get_voice_list(dir=get_voice_dir()):
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
|
||||||
|
|
||||||
|
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
||||||
|
global args
|
||||||
|
|
||||||
|
args.listen = listen
|
||||||
|
args.share = share
|
||||||
|
args.check_for_updates = check_for_updates
|
||||||
|
args.models_from_local_only = models_from_local_only
|
||||||
|
args.low_vram = low_vram
|
||||||
|
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
||||||
|
args.device_override = device_override
|
||||||
|
args.sample_batch_size = sample_batch_size
|
||||||
|
args.embed_output_metadata = embed_output_metadata
|
||||||
|
args.latents_lean_and_mean = latents_lean_and_mean
|
||||||
|
args.voice_fixer = voice_fixer
|
||||||
|
args.voice_fixer_use_cuda = voice_fixer_use_cuda
|
||||||
|
args.concurrency_count = concurrency_count
|
||||||
|
args.output_sample_rate = output_sample_rate
|
||||||
|
args.output_volume = output_volume
|
||||||
|
|
||||||
|
settings = {
|
||||||
|
'listen': None if args.listen else args.listen,
|
||||||
|
'share': args.share,
|
||||||
|
'low-vram':args.low_vram,
|
||||||
|
'check-for-updates':args.check_for_updates,
|
||||||
|
'models-from-local-only':args.models_from_local_only,
|
||||||
|
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
||||||
|
'device-override': args.device_override,
|
||||||
|
'sample-batch-size': args.sample_batch_size,
|
||||||
|
'embed-output-metadata': args.embed_output_metadata,
|
||||||
|
'latents-lean-and-mean': args.latents_lean_and_mean,
|
||||||
|
'voice-fixer': args.voice_fixer,
|
||||||
|
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
|
||||||
|
'concurrency-count': args.concurrency_count,
|
||||||
|
'output-sample-rate': args.output_sample_rate,
|
||||||
|
'output-volume': args.output_volume,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(settings, indent='\t') )
|
||||||
|
|
||||||
|
def read_generate_settings(file, read_latents=True, read_json=True):
|
||||||
|
j = None
|
||||||
|
latents = None
|
||||||
|
|
||||||
|
if file is not None:
|
||||||
|
if hasattr(file, 'name'):
|
||||||
|
file = file.name
|
||||||
|
|
||||||
|
if file[-4:] == ".wav":
|
||||||
|
metadata = music_tag.load_file(file)
|
||||||
|
if 'lyrics' in metadata:
|
||||||
|
j = json.loads(str(metadata['lyrics']))
|
||||||
|
elif file[-5:] == ".json":
|
||||||
|
with open(file, 'r') as f:
|
||||||
|
j = json.load(f)
|
||||||
|
|
||||||
|
if j is None:
|
||||||
|
print("No metadata found in audio file to read")
|
||||||
|
else:
|
||||||
|
if 'latents' in j:
|
||||||
|
if read_latents:
|
||||||
|
latents = base64.b64decode(j['latents'])
|
||||||
|
del j['latents']
|
||||||
|
|
||||||
|
|
||||||
|
if "time" in j:
|
||||||
|
j["time"] = "{:.3f}".format(j["time"])
|
||||||
|
|
||||||
|
return (
|
||||||
|
j,
|
||||||
|
latents,
|
||||||
|
)
|
389
src/webui.py
389
src/webui.py
|
@ -21,6 +21,71 @@ from utils import *
|
||||||
|
|
||||||
args = setup_args()
|
args = setup_args()
|
||||||
|
|
||||||
|
def run_generation(
|
||||||
|
text,
|
||||||
|
delimiter,
|
||||||
|
emotion,
|
||||||
|
prompt,
|
||||||
|
voice,
|
||||||
|
mic_audio,
|
||||||
|
voice_latents_chunks,
|
||||||
|
seed,
|
||||||
|
candidates,
|
||||||
|
num_autoregressive_samples,
|
||||||
|
diffusion_iterations,
|
||||||
|
temperature,
|
||||||
|
diffusion_sampler,
|
||||||
|
breathing_room,
|
||||||
|
cvvp_weight,
|
||||||
|
top_p,
|
||||||
|
diffusion_temperature,
|
||||||
|
length_penalty,
|
||||||
|
repetition_penalty,
|
||||||
|
cond_free_k,
|
||||||
|
experimental_checkboxes,
|
||||||
|
progress=gr.Progress(track_tqdm=True)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
sample, outputs, stats = generate(
|
||||||
|
text,
|
||||||
|
delimiter,
|
||||||
|
emotion,
|
||||||
|
prompt,
|
||||||
|
voice,
|
||||||
|
mic_audio,
|
||||||
|
voice_latents_chunks,
|
||||||
|
seed,
|
||||||
|
candidates,
|
||||||
|
num_autoregressive_samples,
|
||||||
|
diffusion_iterations,
|
||||||
|
temperature,
|
||||||
|
diffusion_sampler,
|
||||||
|
breathing_room,
|
||||||
|
cvvp_weight,
|
||||||
|
top_p,
|
||||||
|
diffusion_temperature,
|
||||||
|
length_penalty,
|
||||||
|
repetition_penalty,
|
||||||
|
cond_free_k,
|
||||||
|
experimental_checkboxes,
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
message = str(e)
|
||||||
|
if message == "Kill signal detected":
|
||||||
|
reload_tts()
|
||||||
|
|
||||||
|
raise gr.Error(message)
|
||||||
|
|
||||||
|
|
||||||
|
return (
|
||||||
|
outputs[0],
|
||||||
|
gr.update(value=sample, visible=sample is not None),
|
||||||
|
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
||||||
|
gr.update(visible=len(outputs) > 1),
|
||||||
|
gr.update(value=stats, visible=True),
|
||||||
|
)
|
||||||
|
|
||||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
global tts
|
global tts
|
||||||
global args
|
global args
|
||||||
|
@ -58,230 +123,6 @@ def update_presets(value):
|
||||||
else:
|
else:
|
||||||
return (gr.update(), gr.update())
|
return (gr.update(), gr.update())
|
||||||
|
|
||||||
def read_generate_settings(file, read_latents=True, read_json=True):
|
|
||||||
j = None
|
|
||||||
latents = None
|
|
||||||
|
|
||||||
if file is not None:
|
|
||||||
if hasattr(file, 'name'):
|
|
||||||
file = file.name
|
|
||||||
|
|
||||||
if file[-4:] == ".wav":
|
|
||||||
metadata = music_tag.load_file(file)
|
|
||||||
if 'lyrics' in metadata:
|
|
||||||
j = json.loads(str(metadata['lyrics']))
|
|
||||||
elif file[-5:] == ".json":
|
|
||||||
with open(file, 'r') as f:
|
|
||||||
j = json.load(f)
|
|
||||||
|
|
||||||
if j is None:
|
|
||||||
gr.Error("No metadata found in audio file to read")
|
|
||||||
else:
|
|
||||||
if 'latents' in j:
|
|
||||||
if read_latents:
|
|
||||||
latents = base64.b64decode(j['latents'])
|
|
||||||
del j['latents']
|
|
||||||
|
|
||||||
|
|
||||||
if "time" in j:
|
|
||||||
j["time"] = "{:.3f}".format(j["time"])
|
|
||||||
|
|
||||||
return (
|
|
||||||
j,
|
|
||||||
latents,
|
|
||||||
)
|
|
||||||
|
|
||||||
def import_voice(file, saveAs = None):
|
|
||||||
global args
|
|
||||||
|
|
||||||
j, latents = read_generate_settings(file, read_latents=True)
|
|
||||||
|
|
||||||
if j is not None and saveAs is None:
|
|
||||||
saveAs = j['voice']
|
|
||||||
if saveAs is None or saveAs == "":
|
|
||||||
raise gr.Error("Specify a voice name")
|
|
||||||
|
|
||||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
|
||||||
if latents:
|
|
||||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
|
||||||
f.write(latents)
|
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
|
||||||
print(f"Imported latents to {latents}")
|
|
||||||
else:
|
|
||||||
filename = file.name
|
|
||||||
if filename[-4:] != ".wav":
|
|
||||||
raise gr.Error("Please convert to a WAV first")
|
|
||||||
|
|
||||||
path = f"{outdir}/{os.path.basename(filename)}"
|
|
||||||
waveform, sampling_rate = torchaudio.load(filename)
|
|
||||||
|
|
||||||
if args.voice_fixer:
|
|
||||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
|
||||||
if sampling_rate != 44100:
|
|
||||||
print(f"Resampling imported voice sample: {path}")
|
|
||||||
resampler = torchaudio.transforms.Resample(
|
|
||||||
sampling_rate,
|
|
||||||
44100,
|
|
||||||
lowpass_filter_width=16,
|
|
||||||
rolloff=0.85,
|
|
||||||
resampling_method="kaiser_window",
|
|
||||||
beta=8.555504641634386,
|
|
||||||
)
|
|
||||||
waveform = resampler(waveform)
|
|
||||||
sampling_rate = 44100
|
|
||||||
|
|
||||||
torchaudio.save(path, waveform, sampling_rate)
|
|
||||||
|
|
||||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
|
||||||
voicefixer.restore(
|
|
||||||
input = path,
|
|
||||||
output = path,
|
|
||||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
|
||||||
#mode=mode,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
torchaudio.save(path, waveform, sampling_rate)
|
|
||||||
|
|
||||||
|
|
||||||
print(f"Imported voice to {path}")
|
|
||||||
|
|
||||||
|
|
||||||
def import_generate_settings(file="./config/generate.json"):
|
|
||||||
settings, _ = read_generate_settings(file, read_latents=False)
|
|
||||||
|
|
||||||
if settings is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return (
|
|
||||||
None if 'text' not in settings else settings['text'],
|
|
||||||
None if 'delimiter' not in settings else settings['delimiter'],
|
|
||||||
None if 'emotion' not in settings else settings['emotion'],
|
|
||||||
None if 'prompt' not in settings else settings['prompt'],
|
|
||||||
None if 'voice' not in settings else settings['voice'],
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None if 'seed' not in settings else settings['seed'],
|
|
||||||
None if 'candidates' not in settings else settings['candidates'],
|
|
||||||
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
|
|
||||||
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
|
|
||||||
0.8 if 'temperature' not in settings else settings['temperature'],
|
|
||||||
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
|
|
||||||
8 if 'breathing_room' not in settings else settings['breathing_room'],
|
|
||||||
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
|
|
||||||
0.8 if 'top_p' not in settings else settings['top_p'],
|
|
||||||
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
|
|
||||||
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
|
|
||||||
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
|
|
||||||
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
|
|
||||||
None if 'experimentals' not in settings else settings['experimentals'],
|
|
||||||
)
|
|
||||||
|
|
||||||
def curl(url):
|
|
||||||
try:
|
|
||||||
req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
|
|
||||||
conn = urllib.request.urlopen(req)
|
|
||||||
data = conn.read()
|
|
||||||
data = data.decode()
|
|
||||||
data = json.loads(data)
|
|
||||||
conn.close()
|
|
||||||
return data
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def check_for_updates():
|
|
||||||
if not os.path.isfile('./.git/FETCH_HEAD'):
|
|
||||||
print("Cannot check for updates: not from a git repo")
|
|
||||||
return False
|
|
||||||
|
|
||||||
with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
|
|
||||||
head = f.read()
|
|
||||||
|
|
||||||
match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
|
|
||||||
if match is None or len(match) == 0:
|
|
||||||
print("Cannot check for updates: cannot parse FETCH_HEAD")
|
|
||||||
return False
|
|
||||||
|
|
||||||
match = match[0]
|
|
||||||
|
|
||||||
local = match[0]
|
|
||||||
host = match[1]
|
|
||||||
owner = match[2]
|
|
||||||
repo = match[3]
|
|
||||||
|
|
||||||
res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances
|
|
||||||
|
|
||||||
if res is None or len(res) == 0:
|
|
||||||
print("Cannot check for updates: cannot fetch from remote")
|
|
||||||
return False
|
|
||||||
|
|
||||||
remote = res[0]["commit"]["id"]
|
|
||||||
|
|
||||||
if remote != local:
|
|
||||||
print(f"New version found: {local[:8]} => {remote[:8]}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def reload_tts():
|
|
||||||
global tts
|
|
||||||
del tts
|
|
||||||
tts = setup_tortoise(restart=True)
|
|
||||||
|
|
||||||
def cancel_generate():
|
|
||||||
tortoise.api.STOP_SIGNAL = True
|
|
||||||
|
|
||||||
def get_voice_list(dir=get_voice_dir()):
|
|
||||||
os.makedirs(dir, exist_ok=True)
|
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"]
|
|
||||||
|
|
||||||
def update_voices():
|
|
||||||
return (
|
|
||||||
gr.Dropdown.update(choices=get_voice_list()),
|
|
||||||
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
|
||||||
)
|
|
||||||
|
|
||||||
def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):
|
|
||||||
global args
|
|
||||||
|
|
||||||
args.listen = listen
|
|
||||||
args.share = share
|
|
||||||
args.check_for_updates = check_for_updates
|
|
||||||
args.models_from_local_only = models_from_local_only
|
|
||||||
args.low_vram = low_vram
|
|
||||||
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
|
||||||
args.device_override = device_override
|
|
||||||
args.sample_batch_size = sample_batch_size
|
|
||||||
args.embed_output_metadata = embed_output_metadata
|
|
||||||
args.latents_lean_and_mean = latents_lean_and_mean
|
|
||||||
args.voice_fixer = voice_fixer
|
|
||||||
args.voice_fixer_use_cuda = voice_fixer_use_cuda
|
|
||||||
args.concurrency_count = concurrency_count
|
|
||||||
args.output_sample_rate = output_sample_rate
|
|
||||||
args.output_volume = output_volume
|
|
||||||
|
|
||||||
settings = {
|
|
||||||
'listen': None if args.listen else args.listen,
|
|
||||||
'share': args.share,
|
|
||||||
'low-vram':args.low_vram,
|
|
||||||
'check-for-updates':args.check_for_updates,
|
|
||||||
'models-from-local-only':args.models_from_local_only,
|
|
||||||
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
|
||||||
'device-override': args.device_override,
|
|
||||||
'sample-batch-size': args.sample_batch_size,
|
|
||||||
'embed-output-metadata': args.embed_output_metadata,
|
|
||||||
'latents-lean-and-mean': args.latents_lean_and_mean,
|
|
||||||
'voice-fixer': args.voice_fixer,
|
|
||||||
'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
|
|
||||||
'concurrency-count': args.concurrency_count,
|
|
||||||
'output-sample-rate': args.output_sample_rate,
|
|
||||||
'output-volume': args.output_volume,
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(settings, indent='\t') )
|
|
||||||
|
|
||||||
def setup_gradio():
|
def setup_gradio():
|
||||||
global args
|
global args
|
||||||
global ui
|
global ui
|
||||||
|
@ -528,6 +369,29 @@ def setup_gradio():
|
||||||
import_voice_name,
|
import_voice_name,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
with gr.Tab("Training"):
|
||||||
|
with gr.Tab("Configuration"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
training_settings = [
|
||||||
|
gr.Slider(label="Batch Size", value=128),
|
||||||
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||||
|
gr.Number(label="Print Frequency", value=50),
|
||||||
|
gr.Number(label="Save Frequency", value=50),
|
||||||
|
]
|
||||||
|
with gr.Column():
|
||||||
|
training_settings = training_settings + [
|
||||||
|
gr.Textbox(label="Training Name", placeholder="finetune"),
|
||||||
|
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
||||||
|
gr.Textbox(label="Dataset Path", placeholder="./experiments/finetune/train.txt"),
|
||||||
|
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
||||||
|
gr.Textbox(label="Validation Path", placeholder="./experiments/finetune/val.txt"),
|
||||||
|
]
|
||||||
|
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||||
|
save_yaml_button.click(save_training_settings,
|
||||||
|
inputs=training_settings,
|
||||||
|
outputs=None
|
||||||
|
)
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
exec_inputs = []
|
exec_inputs = []
|
||||||
|
@ -586,71 +450,15 @@ def setup_gradio():
|
||||||
]
|
]
|
||||||
|
|
||||||
# YUCK
|
# YUCK
|
||||||
def run_generation(
|
def update_voices():
|
||||||
text,
|
|
||||||
delimiter,
|
|
||||||
emotion,
|
|
||||||
prompt,
|
|
||||||
voice,
|
|
||||||
mic_audio,
|
|
||||||
voice_latents_chunks,
|
|
||||||
seed,
|
|
||||||
candidates,
|
|
||||||
num_autoregressive_samples,
|
|
||||||
diffusion_iterations,
|
|
||||||
temperature,
|
|
||||||
diffusion_sampler,
|
|
||||||
breathing_room,
|
|
||||||
cvvp_weight,
|
|
||||||
top_p,
|
|
||||||
diffusion_temperature,
|
|
||||||
length_penalty,
|
|
||||||
repetition_penalty,
|
|
||||||
cond_free_k,
|
|
||||||
experimental_checkboxes,
|
|
||||||
progress=gr.Progress(track_tqdm=True)
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
sample, outputs, stats = generate(
|
|
||||||
text,
|
|
||||||
delimiter,
|
|
||||||
emotion,
|
|
||||||
prompt,
|
|
||||||
voice,
|
|
||||||
mic_audio,
|
|
||||||
voice_latents_chunks,
|
|
||||||
seed,
|
|
||||||
candidates,
|
|
||||||
num_autoregressive_samples,
|
|
||||||
diffusion_iterations,
|
|
||||||
temperature,
|
|
||||||
diffusion_sampler,
|
|
||||||
breathing_room,
|
|
||||||
cvvp_weight,
|
|
||||||
top_p,
|
|
||||||
diffusion_temperature,
|
|
||||||
length_penalty,
|
|
||||||
repetition_penalty,
|
|
||||||
cond_free_k,
|
|
||||||
experimental_checkboxes,
|
|
||||||
progress
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
message = str(e)
|
|
||||||
if message == "Kill signal detected":
|
|
||||||
reload_tts()
|
|
||||||
|
|
||||||
raise gr.Error(message)
|
|
||||||
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
outputs[0],
|
gr.Dropdown.update(choices=get_voice_list()),
|
||||||
gr.update(value=sample, visible=sample is not None),
|
gr.Dropdown.update(choices=get_voice_list("./results/")),
|
||||||
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
|
||||||
gr.update(visible=len(outputs) > 1),
|
|
||||||
gr.update(value=stats, visible=True),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def history_copy_settings( voice, file ):
|
||||||
|
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||||
|
|
||||||
refresh_voices.click(update_voices,
|
refresh_voices.click(update_voices,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -681,21 +489,12 @@ def setup_gradio():
|
||||||
outputs=input_settings
|
outputs=input_settings
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset_generation_settings():
|
|
||||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps({}, indent='\t') )
|
|
||||||
return import_generate_settings()
|
|
||||||
|
|
||||||
reset_generation_settings_button.click(
|
reset_generation_settings_button.click(
|
||||||
fn=reset_generation_settings,
|
fn=reset_generation_settings,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=input_settings
|
outputs=input_settings
|
||||||
)
|
)
|
||||||
|
|
||||||
def history_copy_settings( voice, file ):
|
|
||||||
settings = import_generate_settings( f"./results/{voice}/{file}" )
|
|
||||||
return settings
|
|
||||||
|
|
||||||
history_copy_settings_button.click(history_copy_settings,
|
history_copy_settings_button.click(history_copy_settings,
|
||||||
inputs=[
|
inputs=[
|
||||||
history_voices,
|
history_voices,
|
||||||
|
|
144
training/.template.yaml
Executable file
144
training/.template.yaml
Executable file
|
@ -0,0 +1,144 @@
|
||||||
|
name: ${name}
|
||||||
|
model: extensibletrainer
|
||||||
|
scale: 1
|
||||||
|
gpu_ids: [0] # <-- unless you have multiple gpus, use this
|
||||||
|
start_step: -1
|
||||||
|
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
|
||||||
|
fp16: false # might want to check this out
|
||||||
|
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
|
||||||
|
use_tb_logger: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
name: ${dataset_name}
|
||||||
|
n_workers: 8 # idk what this does
|
||||||
|
batch_size: ${batch_size} # This leads to ~16GB of vram usage on my 3090.
|
||||||
|
mode: paired_voice_audio
|
||||||
|
path: ${dataset_path}
|
||||||
|
fetcher_mode: ['lj'] # CHANGEME if your dataset isn't in LJSpeech format
|
||||||
|
phase: train
|
||||||
|
max_wav_length: 255995
|
||||||
|
max_text_length: 200
|
||||||
|
sample_rate: 22050
|
||||||
|
load_conditioning: True
|
||||||
|
num_conditioning_candidates: 2
|
||||||
|
conditioning_length: 44000
|
||||||
|
use_bpe_tokenizer: True
|
||||||
|
load_aligned_codes: False
|
||||||
|
val:
|
||||||
|
name: ${validation_name}
|
||||||
|
n_workers: 1
|
||||||
|
batch_size: 32 # this could be higher probably
|
||||||
|
mode: paired_voice_audio
|
||||||
|
path: ${validation_path}
|
||||||
|
fetcher_mode: ['lj']
|
||||||
|
phase: val # might be broken idk
|
||||||
|
max_wav_length: 255995
|
||||||
|
max_text_length: 200
|
||||||
|
sample_rate: 22050
|
||||||
|
load_conditioning: True
|
||||||
|
num_conditioning_candidates: 2
|
||||||
|
conditioning_length: 44000
|
||||||
|
use_bpe_tokenizer: True
|
||||||
|
load_aligned_codes: False
|
||||||
|
|
||||||
|
steps:
|
||||||
|
gpt_train:
|
||||||
|
training: gpt
|
||||||
|
loss_log_buffer: 500 # no idea what this does
|
||||||
|
|
||||||
|
# Generally follows the recipe from the DALLE paper.
|
||||||
|
optimizer: adamw # this should be adamw_zero if you're using distributed training
|
||||||
|
optimizer_params:
|
||||||
|
lr: !!float ${learning_rate} # CHANGEME: this was originally 1e-4; I reduced it to 1e-5 because it's fine-tuning, but **you should experiment with this value**
|
||||||
|
weight_decay: !!float 1e-2
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.96
|
||||||
|
clip_grad_eps: 4
|
||||||
|
|
||||||
|
injectors: # TODO: replace this entire sequence with the GptVoiceLatentInjector
|
||||||
|
paired_to_mel:
|
||||||
|
type: torch_mel_spectrogram
|
||||||
|
mel_norm_file: ./experiments/clips_mel_norms.pth
|
||||||
|
in: wav
|
||||||
|
out: paired_mel
|
||||||
|
paired_cond_to_mel:
|
||||||
|
type: for_each
|
||||||
|
subtype: torch_mel_spectrogram
|
||||||
|
mel_norm_file: ./experiments/clips_mel_norms.pth
|
||||||
|
in: conditioning
|
||||||
|
out: paired_conditioning_mel
|
||||||
|
to_codes:
|
||||||
|
type: discrete_token
|
||||||
|
in: paired_mel
|
||||||
|
out: paired_mel_codes
|
||||||
|
dvae_config: "./experiments/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
|
||||||
|
paired_fwd_text:
|
||||||
|
type: generator
|
||||||
|
generator: gpt
|
||||||
|
in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
|
||||||
|
out: [loss_text_ce, loss_mel_ce, logits]
|
||||||
|
losses:
|
||||||
|
text_ce:
|
||||||
|
type: direct
|
||||||
|
weight: .01
|
||||||
|
key: loss_text_ce
|
||||||
|
mel_ce:
|
||||||
|
type: direct
|
||||||
|
weight: 1
|
||||||
|
key: loss_mel_ce
|
||||||
|
|
||||||
|
networks:
|
||||||
|
gpt:
|
||||||
|
type: generator
|
||||||
|
which_model_G: unified_voice2 # none of the unified_voice*.py files actually match the tortoise inference code... 4 and 3 have "alignment_head" (wtf is that?), 2 lacks the types=1 parameter.
|
||||||
|
kwargs:
|
||||||
|
layers: 30 # WAS 8
|
||||||
|
model_dim: 1024 # WAS 512
|
||||||
|
heads: 16 # WAS 8
|
||||||
|
max_text_tokens: 402 # WAS 120
|
||||||
|
max_mel_tokens: 604 # WAS 250
|
||||||
|
max_conditioning_inputs: 2 # WAS 1
|
||||||
|
mel_length_compression: 1024
|
||||||
|
number_text_tokens: 256 # supposed to be 255 for newer unified_voice files
|
||||||
|
number_mel_codes: 8194
|
||||||
|
start_mel_token: 8192
|
||||||
|
stop_mel_token: 8193
|
||||||
|
start_text_token: 255
|
||||||
|
train_solo_embeddings: False # missing in uv3/4
|
||||||
|
use_mel_codes_as_input: True # ditto
|
||||||
|
checkpointing: True
|
||||||
|
#types: 1 # this is MISSING, but in my analysis 1 is equivalent to not having it.
|
||||||
|
#only_alignment_head: False # uv3/4
|
||||||
|
|
||||||
|
path:
|
||||||
|
pretrain_model_gpt: './experiments/autoregressive.pth' # CHANGEME: copy this from tortoise cache
|
||||||
|
strict_load: true
|
||||||
|
#resume_state: ./experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||||
|
|
||||||
|
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
||||||
|
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
||||||
|
niter: 50000
|
||||||
|
warmup_iter: -1
|
||||||
|
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||||
|
val_freq: 500
|
||||||
|
|
||||||
|
default_lr_scheme: MultiStepLR
|
||||||
|
gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
|
||||||
|
lr_gamma: 0.5
|
||||||
|
|
||||||
|
eval:
|
||||||
|
output_state: gen
|
||||||
|
injectors:
|
||||||
|
gen_inj_eval:
|
||||||
|
type: generator
|
||||||
|
generator: generator
|
||||||
|
in: hq
|
||||||
|
out: [gen, codebook_commitment_loss]
|
||||||
|
|
||||||
|
logger:
|
||||||
|
print_freq: 100
|
||||||
|
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
|
||||||
|
visuals: [gen, mel]
|
||||||
|
visual_debug_rate: 500
|
||||||
|
is_mel_spectrogram: true
|
Loading…
Reference in New Issue
Block a user