This commit is contained in:
mrq 2024-06-09 11:39:43 -05:00
parent 132a02c48b
commit 234f9efc6e
6 changed files with 7 additions and 49 deletions

View File

@ -1230,7 +1230,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser("Save trained model to path.") parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--action", type=str) parser.add_argument("--action", type=str)
parser.add_argument("--tasks", type=str) parser.add_argument("--tasks", type=str)
args = parser.parse_args() args, unknown = parser.parse_known_args()
task = args.action task = args.action

View File

@ -59,24 +59,3 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
return tokenized.split(" ") return tokenized.split(" ")
""" """
@torch.no_grad()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", type=str, default=".txt")
args = parser.parse_args()
paths = list(args.folder.rglob(f"*{args.suffix}"))
for path in tqdm(paths):
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
if phone_path.exists():
continue
phones = encode(open(path, "r", encoding="utf-8").read())
open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
if __name__ == "__main__":
main()

View File

@ -378,25 +378,3 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
combined = sum(decoded) / len(decoded) combined = sum(decoded) / len(decoded)
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t() return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", default=".wav")
parser.add_argument("--device", default="cuda")
parser.add_argument("--backend", default="encodec")
args = parser.parse_args()
device = args.device
paths = [*args.folder.rglob(f"*{args.suffix}")]
for path in tqdm(paths):
out_path = _replace_file_extension(path, ".qnt.pt")
if out_path.exists():
continue
qnt = encode_from_file(path, device=device)
torch.save(qnt.cpu(), out_path)
if __name__ == "__main__":
main()

View File

@ -63,7 +63,7 @@ def main():
parser = argparse.ArgumentParser("Save trained model to path.") parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true') parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
args = parser.parse_args() args, unknown = parser.parse_known_args()
if args.module_only: if args.module_only:
cfg.trainer.load_module_only = True cfg.trainer.load_module_only = True

View File

@ -104,7 +104,7 @@ if __name__ == "__main__":
parser.add_argument("--filename", default="log.txt") parser.add_argument("--filename", default="log.txt")
parser.add_argument("--group-level", default=1) parser.add_argument("--group-level", default=1)
args = parser.parse_args() args, unknown = parser.parse_known_args()
path = cfg.rel_path / "logs" path = cfg.rel_path / "logs"
paths = path.rglob(f"./*/{args.filename}") paths = path.rglob(f"./*/{args.filename}")

View File

@ -6,6 +6,7 @@ from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs from .data import fold_inputs, unfold_outputs
from .utils.distributed import is_global_leader
import auraloss import auraloss
import json import json
@ -205,7 +206,7 @@ def train():
# create log folder # create log folder
setup_logging(cfg.log_dir) setup_logging(cfg.log_dir)
# copy config yaml to backup # copy config yaml to backup
if cfg.yaml_path is not None: if cfg.yaml_path is not None and is_global_leader():
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" ) shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()