diff --git a/vall_e/data.py b/vall_e/data.py
index 0c369dc..cd16a70 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1230,7 +1230,7 @@ if __name__ == "__main__":
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--action", type=str)
 	parser.add_argument("--tasks", type=str)
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	task = args.action
 
diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py
index f1308b0..8f7738c 100755
--- a/vall_e/emb/g2p.py
+++ b/vall_e/emb/g2p.py
@@ -58,25 +58,4 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
 		tokenized = tokenized.replace( f' {merge}', merge )
 
 	return tokenized.split(" ")
-	"""
-
-
-@torch.no_grad()
-def main():
-	parser = argparse.ArgumentParser()
-	parser.add_argument("folder", type=Path)
-	parser.add_argument("--suffix", type=str, default=".txt")
-	args = parser.parse_args()
-
-	paths = list(args.folder.rglob(f"*{args.suffix}"))
-
-	for path in tqdm(paths):
-		phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
-		if phone_path.exists():
-			continue
-		phones = encode(open(path, "r", encoding="utf-8").read())
-		open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
-
-
-if __name__ == "__main__":
-	main()
+	"""
\ No newline at end of file
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index bc58802..77e0375 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -377,26 +377,4 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
 			decoded[i] = decoded[i] * scale[i]
 
 	combined = sum(decoded) / len(decoded)
-	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
-
-def main():
-	parser = argparse.ArgumentParser()
-	parser.add_argument("folder", type=Path)
-	parser.add_argument("--suffix", default=".wav")
-	parser.add_argument("--device", default="cuda")
-	parser.add_argument("--backend", default="encodec")
-	args = parser.parse_args()
-
-	device = args.device
-	paths = [*args.folder.rglob(f"*{args.suffix}")]
-
-	for path in tqdm(paths):
-		out_path = _replace_file_extension(path, ".qnt.pt")
-		if out_path.exists():
-			continue
-		qnt = encode_from_file(path, device=device)
-		torch.save(qnt.cpu(), out_path)
-
-
-if __name__ == "__main__":
-	main()
+	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
\ No newline at end of file
diff --git a/vall_e/export.py b/vall_e/export.py
index d86c95d..3abbefd 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -63,7 +63,7 @@ def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	if args.module_only:
 		cfg.trainer.load_module_only = True
diff --git a/vall_e/plot.py b/vall_e/plot.py
index 06f0ff1..4eef13f 100644
--- a/vall_e/plot.py
+++ b/vall_e/plot.py
@@ -104,7 +104,7 @@ if __name__ == "__main__":
 	parser.add_argument("--filename", default="log.txt")
 	parser.add_argument("--group-level", default=1)
 
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	path = cfg.rel_path / "logs"
 	paths = path.rglob(f"./*/{args.filename}")
diff --git a/vall_e/train.py b/vall_e/train.py
index 99e1a44..6958066 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -6,6 +6,7 @@
 from .emb import qnt
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
 from .data import fold_inputs, unfold_outputs
+from .utils.distributed import is_global_leader
 
 import auraloss
 import json
@@ -205,7 +206,7 @@ def train():
 	# create log folder
 	setup_logging(cfg.log_dir)
 	# copy config yaml to backup
-	if cfg.yaml_path is not None:
+	if cfg.yaml_path is not None and is_global_leader():
 		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
 
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()