From 2e03e5ac9394001adfa83ba3b83230942d1ada5c Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 2 Aug 2023 22:57:10 -0500
Subject: [PATCH] Fixed an issue where having fairseq installed at all will brick logging

---
 setup.py                |  5 ++--
 vall_e/models/ar.py     | 17 -------------
 vall_e/models/retnet.py | 53 +++++++++++++++++++++++++++++++++++++++--
 vall_e/train.py         |  4 ++--
 vall_e/utils/utils.py   |  8 +++----
 5 files changed, 59 insertions(+), 28 deletions(-)

diff --git a/setup.py b/setup.py
index 6a514a2..47ef3fd 100755
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@ setup(
         "encodec>=0.1.1",
         "phonemizer>=2.1.0",
         "matplotlib>=3.6.0",
-        "numpy>=1.23.3",
+        "numpy==1.23.0",
         "omegaconf==2.0.6",
         "tqdm>=4.64.1",
         "humanize>=4.4.0",
@@ -58,8 +58,7 @@ setup(
         "auraloss[all]",
         "vocos",
         "h5py",
-        "git+https://github.com/microsoft/torchscale",
-        "fairseq",
+        "torchscale @ git+https://github.com/microsoft/torchscale",
     ],
     url="https://git.ecker.tech/mrq/vall-e",
 )
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 4213d0b..fabcd7b 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -66,24 +66,7 @@ class AR(Base):
             shift_targ_list=True,
             return_all_resp=False,
         )
-        else:
-            return self._generate(
-                text_list,
-                proms_list,
-                max_steps,
-                sampling_temperature,
-
-                naive=naive,
-            )
 
-    def _generate(
-        self,
-        text_list: list[Tensor],
-        proms_list: list[Tensor],
-        max_steps: int,
-        sampling_temperature: float,
-        naive: bool = True,
-    ):
         device = text_list[0].device
         resp_list: list[Tensor] = [
             torch.zeros(0, device=device).to(torch.int16) for _ in text_list
diff --git a/vall_e/models/retnet.py b/vall_e/models/retnet.py
index 6938e4b..73e8fe9 100755
--- a/vall_e/models/retnet.py
+++ b/vall_e/models/retnet.py
@@ -1,5 +1,54 @@
-from fairseq.models import FairseqIncrementalDecoder
-from fairseq.incremental_decoding_utils import with_incremental_state
+"""
+# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
+# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
+"""
+
+import uuid
+from typing import Dict, Optional
+
+from torch import Tensor
+
+class FairseqIncrementalState(object):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.init_incremental_state()
+
+    def init_incremental_state(self):
+        self._incremental_state_id = str(uuid.uuid4())
+
+    def _get_full_incremental_state_key(self, key: str) -> str:
+        return "{}.{}".format(self._incremental_state_id, key)
+
+    def get_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+    ) -> Optional[Dict[str, Optional[Tensor]]]:
+        """Helper for getting incremental state for an nn.Module."""
+        full_key = self._get_full_incremental_state_key(key)
+        if incremental_state is None or full_key not in incremental_state:
+            return None
+        return incremental_state[full_key]
+
+    def set_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+        value: Dict[str, Optional[Tensor]],
+    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
+        """Helper for setting incremental state for an nn.Module."""
+        if incremental_state is not None:
+            full_key = self._get_full_incremental_state_key(key)
+            incremental_state[full_key] = value
+        return incremental_state
+
+
+def with_incremental_state(cls):
+    cls.__bases__ = (FairseqIncrementalState,) + tuple(
+        b for b in cls.__bases__ if b != FairseqIncrementalState
+    )
+    return cls
+
 from torchscale.architecture.config import RetNetConfig
 from torchscale.architecture.retnet import RetNetDecoder as Decoder
 
diff --git a/vall_e/train.py b/vall_e/train.py
index a42f135..9bf8686 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -83,13 +83,13 @@ def load_engines():
     return trainer.load_engines(engines, cfg)
 
 def main():
+    setup_logging(cfg.log_dir)
+
     #dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
     if not deepspeed._initialized_dist:
         deepspeed._initialized_dist = True
         deepspeed.init_distributed()
 
-    setup_logging(cfg.log_dir)
-
     train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 
     def train_feeder(engines, batch, name):
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 30d108a..988f595 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -85,7 +85,7 @@ def load_state_dict_non_strict(model, state_dict, logger=None):
     model.load_state_dict(state_dict, strict=False)
 
 class TqdmLoggingHandler(logging.Handler):
-    def __init__(self, level=logging.NOTSET):
+    def __init__(self, level=logging.INFO):
         super().__init__(level)
 
     def emit(self, record):
@@ -93,8 +93,8 @@ class TqdmLoggingHandler(logging.Handler):
             msg = self.format(record)
             tqdm.write(msg)
             self.flush()
-        except Exception:
-            self.handleError(record)
+        except Exception as e:
+            self.handleError(record)
 
 @global_leader_only
 def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
@@ -116,13 +116,13 @@ def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
         file_handler.setLevel(logging.DEBUG)
         handlers.append(file_handler)
 
+
     logging.basicConfig(
         level=logging.getLevelName(log_level.upper()),
         format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
         handlers=handlers,
     )
 
-
 @overload
 def tree_map(fn: Callable, x: list[T]) -> list[T]: ...
 
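
Note: below is a minimal, self-contained sketch (not code from this repo) of the Python logging behavior the commit appears to be working around. The assumption is that a dependency such as fairseq attaches handlers to the root logger as an import side effect; once that happens, a later plain logging.basicConfig() call is silently ignored, which matches the "brick logging" symptom and is consistent with both dropping the fairseq dependency and moving setup_logging() ahead of deepspeed.init_distributed(). The force=True escape hatch shown at the end is only an alternative illustration (Python 3.8+), not the fix this patch uses.

import logging

# Stand-in for a third-party package (the commit blames fairseq) that
# configures the root logger as an import side effect.
logging.basicConfig(level=logging.WARNING, format="THIRD PARTY | %(message)s")

# The application's own configuration runs later. Because the root logger
# already has a handler, this call is a silent no-op: the level, format and
# handlers requested here are never applied.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
)

log = logging.getLogger("vall_e")
log.info("swallowed: the root logger is still at WARNING with the foreign format")

# Python 3.8+ escape hatch: force=True removes existing root handlers before
# reconfiguring. (This patch instead removes the offending dependency and
# calls setup_logging() before anything else can touch the root logger.)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
    force=True,
)
log.info("now emitted with the application's own format and level")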