ugh
parent 132a02c48b
commit 234f9efc6e
@@ -1230,7 +1230,7 @@ if __name__ == "__main__":
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--action", type=str)
 	parser.add_argument("--tasks", type=str)
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	task = args.action
 
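Note: several hunks in this commit swap parser.parse_args() for parser.parse_known_args(). The difference is that parse_args() exits with an error on any unrecognized flag, while parse_known_args() returns an (args, unknown) tuple and collects the unrecognized tokens into a list. A minimal sketch of the stdlib behavior (the flag names below are illustrative, not from this repo):

	import argparse

	parser = argparse.ArgumentParser()
	parser.add_argument("--action", type=str)

	# parse_args(["--action", "export", "--extra", "1"]) would exit(2) with
	# "error: unrecognized arguments: --extra 1"; parse_known_args() tolerates it:
	args, unknown = parser.parse_known_args(["--action", "export", "--extra", "1"])
	print(args.action)  # export
	print(unknown)      # ['--extra', '1']

This makes the entry points tolerant of extra flags (e.g. ones forwarded by a shared launcher), at the cost of no longer catching mistyped options.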
@@ -58,25 +58,4 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
 	tokenized = tokenized.replace( f' {merge}', merge )
 
 	return tokenized.split(" ")
 """
-
-
-@torch.no_grad()
-def main():
-	parser = argparse.ArgumentParser()
-	parser.add_argument("folder", type=Path)
-	parser.add_argument("--suffix", type=str, default=".txt")
-	args = parser.parse_args()
-
-	paths = list(args.folder.rglob(f"*{args.suffix}"))
-
-	for path in tqdm(paths):
-		phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
-		if phone_path.exists():
-			continue
-		phones = encode(open(path, "r", encoding="utf-8").read())
-		open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
-
-
-if __name__ == "__main__":
-	main()
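The deleted main() walked a folder of transcripts and wrote a .phn.txt phoneme file next to each one, skipping already-processed files so interrupted runs could resume. One subtlety in the output naming: Path.stem only strips the final extension, so the extra split(".")[0] is what reduces multi-dotted names to the bare utterance id. A quick illustration with stdlib pathlib (file names hypothetical):

	from pathlib import Path

	path = Path("data/utt1.normalized.txt")
	# .stem strips only the last suffix -> "utt1.normalized"
	# .split(".")[0] keeps the bare id  -> "utt1"
	phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
	print(phone_path)  # data/utt1.phn.txt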
@@ -377,26 +377,4 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
 			decoded[i] = decoded[i] * scale[i]
 
 	combined = sum(decoded) / len(decoded)
 	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
-
-def main():
-	parser = argparse.ArgumentParser()
-	parser.add_argument("folder", type=Path)
-	parser.add_argument("--suffix", default=".wav")
-	parser.add_argument("--device", default="cuda")
-	parser.add_argument("--backend", default="encodec")
-	args = parser.parse_args()
-
-	device = args.device
-	paths = [*args.folder.rglob(f"*{args.suffix}")]
-
-	for path in tqdm(paths):
-		out_path = _replace_file_extension(path, ".qnt.pt")
-		if out_path.exists():
-			continue
-		qnt = encode_from_file(path, device=device)
-		torch.save(qnt.cpu(), out_path)
-
-
-if __name__ == "__main__":
-	main()
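The deleted quantizer CLI followed the same resumable pattern, saving each utterance's codes as a .qnt.pt tensor via torch.save(qnt.cpu(), out_path). Moving the tensor to CPU before saving matters: torch.load restores tensors to the device they were saved from by default, so a CUDA-resident tensor would fail to load on a CPU-only machine unless map_location is passed. A sketch of the round trip (shapes and file names hypothetical):

	import torch

	qnt = torch.randint(0, 1024, (8, 500))  # stand-in for real quantized codes
	torch.save(qnt.cpu(), "utt1.qnt.pt")    # store a device-agnostic CPU tensor

	# later: load safely anywhere; map_location guards against stray CUDA storages
	codes = torch.load("utt1.qnt.pt", map_location="cpu")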
@@ -63,7 +63,7 @@ def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	if args.module_only:
 		cfg.trainer.load_module_only = True
@@ -104,7 +104,7 @@ if __name__ == "__main__":
 
 	parser.add_argument("--filename", default="log.txt")
 	parser.add_argument("--group-level", default=1)
-	args = parser.parse_args()
+	args, unknown = parser.parse_known_args()
 
 	path = cfg.rel_path / "logs"
 	paths = path.rglob(f"./*/{args.filename}")
@@ -6,6 +6,7 @@ from .emb import qnt
 
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
 from .data import fold_inputs, unfold_outputs
+from .utils.distributed import is_global_leader
 
 import auraloss
 import json
@@ -205,7 +206,7 @@ def train():
 	# create log folder
 	setup_logging(cfg.log_dir)
 	# copy config yaml to backup
-	if cfg.yaml_path is not None:
+	if cfg.yaml_path is not None and is_global_leader():
 		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
 
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
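Under distributed training every rank executes train(), so before this change each process raced to copy the same config YAML into the log folder; the new is_global_leader() guard restricts that filesystem side effect to a single process. The repo imports the helper from .utils.distributed; a common way such a check is implemented (an assumption, not necessarily this codebase's version):

	import os

	def is_global_leader() -> bool:
		# Hypothetical sketch: launchers such as torchrun and DeepSpeed export
		# RANK, and rank 0 is conventionally treated as the global leader.
		return int(os.environ.get("RANK", 0)) == 0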