forked from mrq/ai-voice-cloning
Compare commits
34 Commits
350d2d5a95 ... 5f80ee9b38

Commits in this comparison:
5f80ee9b38, 29c270d1cc, 7fc8f4c45a, 7110b878b7, 13b65d8775, b72f2216bf, 690947ad36, 6f0f148782, 578a5bcadd, b4dc103931, a657623cbc, 533b73e083, f5fab33e9c, 4aa240d48a, 00b173857d, dc46fdc7d0, 29290f574e, 0a5483e57a, e613299304, ce24ba41e2, 5f4215b3ef, 5d73d9e71c, 9abcb0f193, fb1cfd059f, 1ec3344999, a902913780, 2060b6f21c, eeddd4cb6b, 72a38ff2fc, 91a0c495ff, 2626364c40, ac645e0a20, e2a6dc1c0a, a325496661
README.md:

```diff
@@ -1,8 +1,8 @@
 # AI Voice Cloning
 
-This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
+> **Note** This project has been in dire need of being rewritten from the ground up for some time. Apologies for any crust from my rather spaghetti code.
 
-Similar to my own findings for Stable Diffusion image generation, this rentry may appear a little disheveled as I note my new findings with TorToiSe. Please keep this in mind if the guide seems to shift a bit or sound confusing.
+This [repo](https://git.ecker.tech/mrq/ai-voice-cloning)/[rentry](https://rentry.org/AI-Voice-Cloning/) aims to serve as both a foolproof guide for setting up AI voice cloning tools for legitimate, local use on Windows/Linux, as well as a stepping stone for anons that genuinely want to play around with [TorToiSe](https://github.com/neonbjb/tortoise-tts).
 
 >\>Ugh... why bother when I can just abuse 11.AI?
 
```
VALL-E training configuration template:

```diff
@@ -1,13 +1,106 @@
-data_dirs: [./training/${voice}/valle/]
-spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
-
-max_phones: 72
-
-models: '${models}'
-batch_size: ${batch_size}
-gradient_accumulation_steps: ${gradient_accumulation_size}
-eval_batch_size: ${batch_size}
-
-max_iter: ${iterations}
-save_ckpt_every: ${save_rate}
-eval_every: ${validation_rate}
+dataset:
+  training: [
+    "./training/${voice}/valle/",
+  ]
+  noise: [
+    "./training/valle/data/Other/noise/",
+  ]
+
+  speaker_name_getter: "lambda p: p.parts[-3]" # "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
+
+  use_hdf5: False
+  hdf5_name: data.h5
+  hdf5_flag: r
+  validate: True
+
+  workers: 4
+  cache: False
+
+  phones_range: [4, 64]
+  duration_range: [1.0, 8.0]
+
+  random_utterance: 1.0
+  max_prompts: 3
+  prompt_duration: 3.0
+
+  sample_type: path
+
+  tasks_list: ["tts"] # ["tts", "ns", "sr", "tse", "cse", "nse", "tts"]
+
+models:
+  _max_levels: 8
+  _models:
+    - name: "ar"
+      size: "full"
+      resp_levels: 1
+      prom_levels: 2
+      tasks: 8
+      arch_type: "retnet"
+
+    - name: "nar"
+      size: "full"
+      resp_levels: 3
+      prom_levels: 4
+      tasks: 8
+      arch_type: "retnet"
+
+hyperparameters:
+  batch_size: ${batch_size}
+  gradient_accumulation_steps: ${gradient_accumulation_size}
+  gradient_clipping: 100
+
+  optimizer: AdamW
+  learning_rate: 1.0e-4
+
+  scheduler_type: ""
+
+evaluation:
+  batch_size: ${batch_size}
+  frequency: ${validation_rate}
+  size: 16
+
+  steps: 300
+  ar_temperature: 0.95
+  nar_temperature: 0.25
+
+trainer:
+  iterations: ${iterations}
+
+  save_tag: step
+  save_on_oom: True
+  save_on_quit: True
+  export_on_save: True
+  export_on_quit: True
+  save_frequency: ${save_rate}
+
+  keep_last_checkpoints: 4
+
+  aggressive_optimizations: False
+
+  load_state_dict: True
+  #strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
+  #restart_step_count: True
+
+  gc_mode: None # "global_step"
+
+  weight_dtype: bfloat16
+
+  backend: deepspeed
+  deepspeed:
+    zero_optimization_level: 2
+    use_compression_training: True
+
+inference:
+  use_vocos: True
+  normalize: False
+
+  weight_dtype: float32
+
+bitsandbytes:
+  enabled: False
+  injects: True
+  linear: True
+  embedding: True
```
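The `${voice}`, `${batch_size}`, `${gradient_accumulation_size}`, `${iterations}`, `${save_rate}`, and `${validation_rate}` tokens above are placeholders that get filled in when a per-voice training config is written out; the actual substitution lives in `src/utils.py`, whose diff is suppressed further down. As a rough sketch only (the function name, paths, and values here are hypothetical, not taken from the diff), the templating amounts to a plain string replacement:

```python
# Hypothetical illustration of how the ${...} placeholders in the template
# might be rendered into a concrete per-voice config. The real logic is in
# the suppressed src/utils.py changes.
from pathlib import Path


def render_valle_config(template_path: str, output_path: str, **values) -> str:
    text = Path(template_path).read_text()
    for key, value in values.items():
        # Replace every occurrence of ${key} with the supplied value.
        text = text.replace("${" + key + "}", str(value))
    Path(output_path).write_text(text)
    return text


# Example usage with assumed locations and training settings.
render_valle_config(
    "./models/valle/config.yaml",            # assumed template location
    "./training/myvoice/valle/config.yaml",  # assumed output location
    voice="myvoice",
    batch_size=8,
    gradient_accumulation_size=4,
    iterations=5000,
    save_rate=500,
    validation_rate=500,
)
```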
Submodule pointer update:

```diff
@@ -1 +1 @@
-Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0
+Subproject commit b10c58436d6871c26485d30b203e6cfdd4167602
```
requirements.txt:

```diff
@@ -7,4 +7,5 @@ music-tag
 voicefixer
 psutil
 phonemizer
 pydantic==1.10.11
+websockets
```
src/api/websocket_server.py (new file, 84 lines):

```python
import asyncio
import json
from threading import Thread

from websockets.server import serve

from utils import generate, get_autoregressive_models, get_voice_list, args, update_autoregressive_model, update_diffusion_model, update_tokenizer


# this is a not so nice workaround to set values to None if their string value is "None"
def replaceNoneStringWithNone(message):
    ignore_fields = ['text']  # list of fields which CAN have "None" as literal String value

    for member in message:
        if message[member] == 'None' and member not in ignore_fields:
            message[member] = None

    return message


async def _handle_generate(websocket, message):
    # update args parameters which control the model settings
    if message.get('autoregressive_model'):
        update_autoregressive_model(message['autoregressive_model'])

    if message.get('diffusion_model'):
        update_diffusion_model(message['diffusion_model'])

    if message.get('tokenizer_json'):
        update_tokenizer(message['tokenizer_json'])

    if message.get('sample_batch_size'):
        global args
        args.sample_batch_size = message['sample_batch_size']

    # forward the remaining fields to generate() and echo the result back
    message['result'] = generate(**message)
    await websocket.send(json.dumps(replaceNoneStringWithNone(message)))


async def _handle_get_autoregressive_models(websocket, message):
    message['result'] = get_autoregressive_models()
    await websocket.send(json.dumps(replaceNoneStringWithNone(message)))


async def _handle_get_voice_list(websocket, message):
    message['result'] = get_voice_list()
    await websocket.send(json.dumps(replaceNoneStringWithNone(message)))


async def _handle_message(websocket, message):
    message = replaceNoneStringWithNone(message)

    if message.get('action') and message['action'] == 'generate':
        await _handle_generate(websocket, message)
    elif message.get('action') and message['action'] == 'get_voices':
        await _handle_get_voice_list(websocket, message)
    elif message.get('action') and message['action'] == 'get_autoregressive_models':
        await _handle_get_autoregressive_models(websocket, message)
    else:
        print(f"websocket: unhandled message: {message}")


async def _handle_connection(websocket, path):
    print("websocket: client connected")

    async for message in websocket:
        try:
            await _handle_message(websocket, json.loads(message))
        except ValueError:
            print("websocket: malformed json received")


async def _run(host: str, port: int):
    print(f"websocket: server started on ws://{host}:{port}")

    async with serve(_handle_connection, host, port, ping_interval=None):
        await asyncio.Future()  # run forever


def _run_server(listen_address: str, port: int):
    asyncio.run(_run(host=listen_address, port=port))


def start_websocket_server(listen_address: str, port: int):
    # run the asyncio server in a daemon thread so the web UI keeps control of the main thread
    Thread(target=_run_server, args=[listen_address, port], daemon=True).start()
```
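For context, a minimal client sketch that is not part of this commit: it assumes the server above is reachable on `ws://127.0.0.1:8069` (the real listen address and port come from the new `websocket_listen_address`/`websocket_listen_port` arguments, whose defaults live in the suppressed `src/utils.py` diff) and exercises the `get_voices` and `generate` actions handled by `_handle_message`. The `text` and `voice` fields are assumptions based on the web UI's generate parameters.

```python
import asyncio
import json

from websockets.client import connect  # pip install websockets


async def demo():
    # Address and port are placeholders; match them to your websocket listen settings.
    async with connect("ws://127.0.0.1:8069") as ws:
        # List the available voices.
        await ws.send(json.dumps({"action": "get_voices"}))
        print(json.loads(await ws.recv())["result"])

        # Request a generation; the message dict is forwarded to utils.generate()
        # as keyword arguments, so the accepted fields mirror generate()'s parameters.
        await ws.send(json.dumps({
            "action": "generate",
            "text": "Hello from the websocket API.",
            "voice": "random",  # hypothetical voice name
        }))
        print(json.loads(await ws.recv())["result"])


asyncio.run(demo())
```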
src/main.py:

```diff
@@ -11,6 +11,9 @@ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
 from utils import *
 from webui import *
 
+from api.websocket_server import start_websocket_server
+
+
 if __name__ == "__main__":
     args = setup_args()
 
@@ -23,6 +26,9 @@ if __name__ == "__main__":
     if not args.defer_tts_load:
         tts = load_tts()
 
+    if args.websocket_enabled:
+        start_websocket_server(args.websocket_listen_address, args.websocket_listen_port)
+
     webui.block_thread()
 elif __name__ == "main":
     from fastapi import FastAPI
@@ -37,4 +43,5 @@ elif __name__ == "main":
     app = gr.mount_gradio_app(app, webui, path=args.listen_path)
 
     if not args.defer_tts_load:
-        tts = load_tts()
+        tts = load_tts()
+
```
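`main.py` now reads `args.websocket_enabled`, `args.websocket_listen_address`, and `args.websocket_listen_port`, which have to be registered by `setup_args()` in `src/utils.py`; that diff is suppressed below, so as a hedged sketch only (flag spellings and defaults here are assumptions, not taken from the diff), the new options would look roughly like:

```python
# Assumed shape of the new options; the real flag names and defaults
# are in the suppressed src/utils.py changes.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--websocket-enabled", action="store_true", default=False)
parser.add_argument("--websocket-listen-address", default="127.0.0.1")
parser.add_argument("--websocket-listen-port", type=int, default=8069)
args = parser.parse_args()

# main.py then gates the server on these values:
if args.websocket_enabled:
    print(f"would start websocket server on {args.websocket_listen_address}:{args.websocket_listen_port}")
```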
src/utils.py (7777 changed lines) — file diff suppressed because it is too large.
src/webui.py (1953 changed lines) — file diff suppressed because it is too large.