ugh
This commit is contained in:
parent
132a02c48b
commit
234f9efc6e
|
@ -1230,7 +1230,7 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--action", type=str)
|
||||
parser.add_argument("--tasks", type=str)
|
||||
args = parser.parse_args()
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
task = args.action
|
||||
|
||||
|
|
|
@ -58,25 +58,4 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
|
|||
tokenized = tokenized.replace( f' {merge}', merge )
|
||||
|
||||
return tokenized.split(" ")
|
||||
"""
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("folder", type=Path)
|
||||
parser.add_argument("--suffix", type=str, default=".txt")
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = list(args.folder.rglob(f"*{args.suffix}"))
|
||||
|
||||
for path in tqdm(paths):
|
||||
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
||||
if phone_path.exists():
|
||||
continue
|
||||
phones = encode(open(path, "r", encoding="utf-8").read())
|
||||
open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
"""
|
|
@ -377,26 +377,4 @@ def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
|||
decoded[i] = decoded[i] * scale[i]
|
||||
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("folder", type=Path)
|
||||
parser.add_argument("--suffix", default=".wav")
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--backend", default="encodec")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = args.device
|
||||
paths = [*args.folder.rglob(f"*{args.suffix}")]
|
||||
|
||||
for path in tqdm(paths):
|
||||
out_path = _replace_file_extension(path, ".qnt.pt")
|
||||
if out_path.exists():
|
||||
continue
|
||||
qnt = encode_from_file(path, device=device)
|
||||
torch.save(qnt.cpu(), out_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
|
@ -63,7 +63,7 @@ def main():
|
|||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||
args = parser.parse_args()
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
|
|
@ -104,7 +104,7 @@ if __name__ == "__main__":
|
|||
|
||||
parser.add_argument("--filename", default="log.txt")
|
||||
parser.add_argument("--group-level", default=1)
|
||||
args = parser.parse_args()
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
path = cfg.rel_path / "logs"
|
||||
paths = path.rglob(f"./*/{args.filename}")
|
||||
|
|
|
@ -6,6 +6,7 @@ from .emb import qnt
|
|||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .data import fold_inputs, unfold_outputs
|
||||
from .utils.distributed import is_global_leader
|
||||
|
||||
import auraloss
|
||||
import json
|
||||
|
@ -205,7 +206,7 @@ def train():
|
|||
# create log folder
|
||||
setup_logging(cfg.log_dir)
|
||||
# copy config yaml to backup
|
||||
if cfg.yaml_path is not None:
|
||||
if cfg.yaml_path is not None and is_global_leader():
|
||||
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
|
Loading…
Reference in New Issue
Block a user