Fixed an issue with having fairseq installed at all will brick logging
This commit is contained in:
parent
f6597e2dfe
commit
2e03e5ac93
5
setup.py
5
setup.py
|
@ -45,7 +45,7 @@ setup(
|
||||||
"encodec>=0.1.1",
|
"encodec>=0.1.1",
|
||||||
"phonemizer>=2.1.0",
|
"phonemizer>=2.1.0",
|
||||||
"matplotlib>=3.6.0",
|
"matplotlib>=3.6.0",
|
||||||
"numpy>=1.23.3",
|
"numpy==1.23.0",
|
||||||
"omegaconf==2.0.6",
|
"omegaconf==2.0.6",
|
||||||
"tqdm>=4.64.1",
|
"tqdm>=4.64.1",
|
||||||
"humanize>=4.4.0",
|
"humanize>=4.4.0",
|
||||||
|
@ -58,8 +58,7 @@ setup(
|
||||||
"auraloss[all]",
|
"auraloss[all]",
|
||||||
"vocos",
|
"vocos",
|
||||||
"h5py",
|
"h5py",
|
||||||
"git+https://github.com/microsoft/torchscale",
|
"torchscale @ git+https://github.com/microsoft/torchscale",
|
||||||
"fairseq",
|
|
||||||
],
|
],
|
||||||
url="https://git.ecker.tech/mrq/vall-e",
|
url="https://git.ecker.tech/mrq/vall-e",
|
||||||
)
|
)
|
||||||
|
|
|
@ -66,24 +66,7 @@ class AR(Base):
|
||||||
shift_targ_list=True,
|
shift_targ_list=True,
|
||||||
return_all_resp=False,
|
return_all_resp=False,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return self._generate(
|
|
||||||
text_list,
|
|
||||||
proms_list,
|
|
||||||
max_steps,
|
|
||||||
sampling_temperature,
|
|
||||||
|
|
||||||
naive=naive,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
text_list: list[Tensor],
|
|
||||||
proms_list: list[Tensor],
|
|
||||||
max_steps: int,
|
|
||||||
sampling_temperature: float,
|
|
||||||
naive: bool = True,
|
|
||||||
):
|
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
resp_list: list[Tensor] = [
|
resp_list: list[Tensor] = [
|
||||||
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||||
|
|
|
@ -1,5 +1,54 @@
|
||||||
from fairseq.models import FairseqIncrementalDecoder
|
"""
|
||||||
from fairseq.incremental_decoding_utils import with_incremental_state
|
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
|
||||||
|
# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
class FairseqIncrementalState(object):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.init_incremental_state()
|
||||||
|
|
||||||
|
def init_incremental_state(self):
|
||||||
|
self._incremental_state_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
def _get_full_incremental_state_key(self, key: str) -> str:
|
||||||
|
return "{}.{}".format(self._incremental_state_id, key)
|
||||||
|
|
||||||
|
def get_incremental_state(
|
||||||
|
self,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||||
|
key: str,
|
||||||
|
) -> Optional[Dict[str, Optional[Tensor]]]:
|
||||||
|
"""Helper for getting incremental state for an nn.Module."""
|
||||||
|
full_key = self._get_full_incremental_state_key(key)
|
||||||
|
if incremental_state is None or full_key not in incremental_state:
|
||||||
|
return None
|
||||||
|
return incremental_state[full_key]
|
||||||
|
|
||||||
|
def set_incremental_state(
|
||||||
|
self,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||||
|
key: str,
|
||||||
|
value: Dict[str, Optional[Tensor]],
|
||||||
|
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
|
||||||
|
"""Helper for setting incremental state for an nn.Module."""
|
||||||
|
if incremental_state is not None:
|
||||||
|
full_key = self._get_full_incremental_state_key(key)
|
||||||
|
incremental_state[full_key] = value
|
||||||
|
return incremental_state
|
||||||
|
|
||||||
|
|
||||||
|
def with_incremental_state(cls):
|
||||||
|
cls.__bases__ = (FairseqIncrementalState,) + tuple(
|
||||||
|
b for b in cls.__bases__ if b != FairseqIncrementalState
|
||||||
|
)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
from torchscale.architecture.config import RetNetConfig
|
from torchscale.architecture.config import RetNetConfig
|
||||||
from torchscale.architecture.retnet import RetNetDecoder as Decoder
|
from torchscale.architecture.retnet import RetNetDecoder as Decoder
|
||||||
|
|
|
@ -83,13 +83,13 @@ def load_engines():
|
||||||
return trainer.load_engines(engines, cfg)
|
return trainer.load_engines(engines, cfg)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
setup_logging(cfg.log_dir)
|
||||||
|
|
||||||
#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
|
#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
|
||||||
if not deepspeed._initialized_dist:
|
if not deepspeed._initialized_dist:
|
||||||
deepspeed._initialized_dist = True
|
deepspeed._initialized_dist = True
|
||||||
deepspeed.init_distributed()
|
deepspeed.init_distributed()
|
||||||
|
|
||||||
setup_logging(cfg.log_dir)
|
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
|
|
||||||
def train_feeder(engines, batch, name):
|
def train_feeder(engines, batch, name):
|
||||||
|
|
|
@ -85,7 +85,7 @@ def load_state_dict_non_strict(model, state_dict, logger=None):
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
class TqdmLoggingHandler(logging.Handler):
|
class TqdmLoggingHandler(logging.Handler):
|
||||||
def __init__(self, level=logging.NOTSET):
|
def __init__(self, level=logging.INFO):
|
||||||
super().__init__(level)
|
super().__init__(level)
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
|
@ -93,8 +93,8 @@ class TqdmLoggingHandler(logging.Handler):
|
||||||
msg = self.format(record)
|
msg = self.format(record)
|
||||||
tqdm.write(msg)
|
tqdm.write(msg)
|
||||||
self.flush()
|
self.flush()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
self.handleError(record)
|
self.handleError(record)
|
||||||
|
|
||||||
@global_leader_only
|
@global_leader_only
|
||||||
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
|
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
|
||||||
|
@ -116,13 +116,13 @@ def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
handlers.append(file_handler)
|
handlers.append(file_handler)
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.getLevelName(log_level.upper()),
|
level=logging.getLevelName(log_level.upper()),
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
|
||||||
handlers=handlers,
|
handlers=handlers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def tree_map(fn: Callable, x: list[T]) -> list[T]:
|
def tree_map(fn: Callable, x: list[T]) -> list[T]:
|
||||||
...
|
...
|
||||||
|
|
Loading…
Reference in New Issue
Block a user