moved prints to use logger, edited readme (fused_attn doesnt seem stable for training)
parent d423bc03c2
commit 32287710a2
@@ -161,7 +161,7 @@ For audio backends:
 * `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
 * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
 * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
-* `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100)
+* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
 * `transformers` Llama\*Attention implementations:
 * `eager`: default `LlamaAttention`
 * `sdpa`: integrated `LlamaSdpaAttention` attention model
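A minimal sketch of how a caller might pick one of the backends listed above, assuming a hand-written priority list and the `AVAILABLE_ATTENTIONS` list that the attention hunks further down populate; the names in `PREFERRED` are illustrative, and `fused_attn` is left out per the training-stability note in the README change.

```python
# illustrative only; AVAILABLE_ATTENTIONS is filled by the try/except probes shown later
PREFERRED = ["flash_attention_2", "xformers", "sdpa", "eager"]

def pick_attention(available: list[str]) -> str:
    # first preferred backend that is actually available wins, with a safe fallback
    for backend in PREFERRED:
        if backend in available:
            return backend
    return "eager"

print(pick_attention(["sdpa", "eager"]))  # -> "sdpa"
```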
@@ -9,6 +9,7 @@ import time
 import argparse
 import yaml
 import random
+import logging

 import torch
 import numpy as np
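Most of the modules touched by this commit follow the same pattern: a module-level logger named after the module, with handler and level configuration left to the entry point. A minimal, self-contained sketch of that standard-library pattern (the function name here is made up):

```python
import logging

# one logger per module; records carry the module name so they can be filtered per file
_logger = logging.getLogger(__name__)

def do_work():
    _logger.info("work started")

if __name__ == "__main__":
    # only the entry point configures handlers/levels; library modules just emit records
    logging.basicConfig(level=logging.INFO)
    do_work()
```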
@@ -163,7 +164,8 @@ class Dataset:
 sample_order: str = "interleaved" # duration
 sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
 # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
-sample_shuffle: bool = True #
+# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
+sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0

 tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
 reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
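For context, duration-ordered batching as described by `sample_order` and `sample_max_duration_batch` amounts to packing utterances into batches capped by total seconds rather than by item count. A rough sketch under that assumption; the helper name and the `(path, seconds)` tuple layout are made up for illustration and are not the repo's actual dataloader:

```python
def batch_by_duration(utterances: list[tuple[str, float]], max_seconds: float) -> list[list[str]]:
    # sort by duration so similarly sized utterances land together (less padding),
    # then greedily fill each batch until the per-batch seconds budget is hit
    batches, current, total = [], [], 0.0
    for path, seconds in sorted(utterances, key=lambda x: x[1]):
        if current and total + seconds > max_seconds:
            batches.append(current)
            current, total = [], 0.0
        current.append(path)
        total += seconds
    if current:
        batches.append(current)
    return batches

print(batch_by_duration([("a.wav", 3.0), ("b.wav", 5.0), ("c.wav", 9.0)], max_seconds=10.0))
```

Shuffling the order of the packed batches afterwards keeps the duration grouping while still varying what the model sees each epoch, which is the interaction the `sample_shuffle` comment is wary of.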
@@ -364,6 +366,7 @@ class Model:

 return dict(include=include, exclude=exclude)

+# should be renamed to Adapters
 @dataclass()
 class LoRA:
 name: str = "lora" # vanity name
@@ -638,9 +641,6 @@ class Trainer:
 def scale_loss(self):
 # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
 return self.dtype == torch.float16
-"""
-"""
-

 @dataclass()
 class Inference:
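Why `scale_loss` keys off `float16` only: fp16 gradients underflow without loss scaling, while bf16/fp32 do not need it, and the comment notes that DeepSpeed handles scaling itself. A small, self-contained sketch of the non-DeepSpeed path with `torch.cuda.amp.GradScaler`, using a toy model and data rather than the repo's trainer:

```python
import torch

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 8).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# scaling is only needed (and only enabled) for float16 on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16 and device == "cuda"))

x = torch.randn(4, 8, device=device)
with torch.autocast(device_type=device, dtype=dtype, enabled=(device == "cuda")):
    loss = model(x).square().mean()

scaler.scale(loss).backward()  # no-op scaling when the scaler is disabled
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```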
@@ -670,7 +670,6 @@ class Inference:
 return torch.float8_e4m3fn
 return torch.float32

-# should be renamed to optimizations
 @dataclass()
 class Optimizations:
 injects: bool = False # overwrites default torch classes (not recommended)
@@ -755,6 +754,7 @@ class Config(BaseConfig):

 return self.models[0] if len(self.models) > 0 else None

+# should be renamed to adapters
 @property
 def lora(self):
 for i, lora in enumerate(self.loras):
@@ -795,7 +795,7 @@ class Config(BaseConfig):
 try:
 self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
 except Exception as e:
-print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
+_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
 self.dataset.use_hdf5 = False

 # to-do: prune unused keys
@@ -923,7 +923,7 @@ class Config(BaseConfig):
 cfg.tokenizer = NaiveTokenizer()
 except Exception as e:
 cfg.tokenizer = NaiveTokenizer()
-print("Error while parsing tokenizer:", e)
+_logger.warning(f"Error while parsing tokenizer: {str(e)}")
 pass

@@ -960,6 +960,7 @@ class NaiveTokenizer:
 # tokenize
 return [*map(symmap.get, phones)]

+_logger = logging.getLogger(__name__)

 cfg = Config.from_cli()

@@ -967,7 +968,7 @@ cfg = Config.from_cli()
 try:
 cfg.format()
 except Exception as e:
-print("Error while parsing config YAML:")
+_logger.error(f"Error while parsing config YAML: {str(e)}")
 raise e # throw an error because I'm tired of silent errors messing things up for me

 if __name__ == "__main__":
@@ -495,7 +495,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):

 def _get_hdf5_path(path):
 # to-do: better validation
-#print(path)
 return str(path)

 def _get_hdf5_paths( data_dir, type="training", validate=False ):
@@ -1543,12 +1542,6 @@ if __name__ == "__main__":

 cfg.dataset.workers = 1

-class LoggerOveride:
-def info(self, *args):
-print(*args)
-
-_logger = LoggerOveride()
-
 if args.action == "hdf5":
 create_dataset_hdf5()
 elif args.action == "list-dataset":
@@ -1559,7 +1552,7 @@ if __name__ == "__main__":
 continue
 dataset.append(f'{group}/{name}')

-print(json.dumps(dataset))
+_logger.info(json.dumps(dataset))
 elif args.action == "metadata":
 create_dataset_metadata()
 elif args.action == "sample":
@@ -1581,17 +1574,17 @@ if __name__ == "__main__":
 try:
 decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
 except Exception as e:
-print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
+_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
 try:
 decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
 except Exception as e:
-print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
+_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
 v[i]['proms'][j] = v[i]['proms'][j].shape
 v[i]['resps'][j] = v[i]['resps'][j].shape

 for k, v in samples.items():
 for i in range(len(v)):
-print(f'{k}[{i}]:', v[i])
+_logger.info(f'{k}[{i}]: {v[i]}')
 elif args.action == "validate":
 train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

@@ -1610,11 +1603,11 @@ if __name__ == "__main__":

 phone = phonemes[i]

-print( batch['text'], batch['metadata']['phonemes'] )
+_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )

 missing |= set([phone])

-print( "Missing tokens:", missing )
+_logger.info( f"Missing tokens: {missing}" )


 elif args.action == "tasks":
@@ -1628,13 +1621,13 @@ if __name__ == "__main__":
 if task not in cfg.dataset.tasks_list:
 continue

-print(text, task, cfg.model.resp_levels)
-print( proms.shape, resps.shape )
+_logger.info( f'{text} {task} {cfg.model.resp_levels}')
+_logger.info( f'{proms.shape} {resps.shape}' )

 tokens = 0
 tokens += sum([ text.shape[0] for text in batch["text"] ])
 tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
-print( tokens )
+_logger.info( f'{tokens}' )

 decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
 decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
@@ -18,6 +18,9 @@ Will also generate samples from a provided datset, if requested.
 import argparse
 import base64
 import random
+import logging
+
+_logger = logging.getLogger(__name__)

 from pathlib import Path

@@ -117,9 +120,9 @@ def main():

 samples_dirs["dataset"] = args.demo_dir / "dataset"

-print("Loading dataloader...")
+_logger.info("Loading dataloader...")
 dataloader = create_train_dataloader()
-print("Loaded dataloader.")
+_logger.info("Loaded dataloader.")

 num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

@@ -9,6 +9,9 @@ import argparse
 import torch
 import torchaudio
 import numpy as np
+import logging
+
+_logger = logging.getLogger(__name__)

 from tqdm.auto import tqdm
 from pathlib import Path
@@ -78,7 +81,7 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
 try:
 process_job( outpath, waveform, sample_rate, text, language )
 except Exception as e:
-print(f"Failed to quantize: {outpath}:", e)
+_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
 if raise_exceptions:
 raise e
 continue
@@ -128,7 +131,7 @@ def process(

 for group_name in sorted(os.listdir(f'./{input_audio}/')):
 if not os.path.isdir(f'./{input_audio}/{group_name}/'):
-print("Is not dir:", f'./{input_audio}/{group_name}/')
+_logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
 continue

 if group_name in ignore_groups:
@@ -138,7 +141,7 @@ def process(

 for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
 if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
-print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
+_logger.warning(f'Is not dir: ./{input_audio}/{group_name}/{speaker_id}')
 continue

 if speaker_id in ignore_speakers:
@@ -6,6 +6,9 @@ import math
 import torch
 import torchaudio
 import numpy as np
+import logging
+
+_logger = logging.getLogger(__name__)

 from functools import cache
 from pathlib import Path
@@ -203,7 +206,7 @@ try:

 except Exception as e:
 cfg.inference.use_dac = False
-print(str(e))
+_logger.warning(str(e))

 # uses https://github.com/facebookresearch/AudioDec/
 # I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
@@ -213,7 +216,7 @@ try:
 from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
 except Exception as e:
 cfg.inference.use_audiodec = False
-print(str(e))
+_logger.warning(str(e))
 """

 @cache
@@ -747,8 +750,8 @@ if __name__ == "__main__":
 if args.print:
 torch.set_printoptions(profile="full")

-print( "Metadata:", artifact['metadata'] )
-print( "Codes:", codes.shape, codes )
+_logger.info(f"Metadata: {artifact['metadata']}" )
+_logger.info(f"Codes: {codes.shape}, {codes}" )
 # encode
 else:
 args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)
@@ -17,6 +17,9 @@ from ..models.lora import apply_lora, lora_load_state_dict

 import torch
 import re
+import logging
+
+_logger = logging.getLogger(__name__)

 deepspeed_available = False
 try:
@@ -55,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
 checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

 if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-print("Checkpoint missing, but weights found:", load_path)
+_logger.warning("Checkpoint missing, but weights found:", load_path)
 loads_state_dict = True

 # load state early
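One thing worth noting about conversions like the one above: `print` accepts any number of positional arguments, but the `logging` methods treat extra positional arguments as lazy %-format values, so a message without a matching placeholder will not render the extra value. A small sketch of the two working spellings (the path literal is just an example):

```python
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

# %-style: the argument is substituted lazily when the record is emitted
_logger.warning("Checkpoint missing, but weights found: %s", "/path/to/weights.sft")

# f-string: formatted eagerly, which is fine for infrequent messages
_logger.warning(f"Checkpoint missing, but weights found: {'/path/to/weights.sft'}")
```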
@@ -64,7 +67,7 @@ def load_engines(training=True, **model_kwargs):

 # check if config is defined in state, and re-initialize the model
 if "config" in state and False:
-print("Model config definition in weights, re-loading...")
+_logger.warning("Model config definition in weights, re-loading...")
 config_state = state["config"]
 model = get_model( config=cfg.model.__class__( *config_state ), training=training )

@@ -201,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
 if cfg.lora is not None:
 lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 if lora_path.exists():
-print( "Loaded LoRA state dict:", lora_path )
+_logger.info( "Loaded LoRA state dict:", lora_path )

 state = torch_load(lora_path, device=cfg.device)
 state = state['lora' if 'lora' in state else 'module']
@@ -367,7 +367,7 @@ class Engines(dict[str, Engine]):
 state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

 torch_save(state_dict, save_path)
-print(f"Exported {name} to {save_path}")
+_logger.info(f"Exported {name} to {save_path}")

 def save_checkpoint(self, tag=None):
 if not tag:
@@ -385,7 +385,7 @@ class Engines(dict[str, Engine]):
 try:
 engine.save_checkpoint(save_dir, tag=tag)
 except Exception as e:
-print(f'Failed to save checkpoint for engine {name}:', str(e))
+_logger.warning(f'Failed to save checkpoint for engine {name}:', str(e))

 # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
 if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@@ -395,7 +395,7 @@ class Engines(dict[str, Engine]):
 for d in checkpoints:
 if not d.is_dir() or not d.exists():
 continue
-print("Removing", d)
+_logger.info("Removing", d)
 for p in d.iterdir():
 p.unlink()
 d.rmdir()
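The two hunks above belong to the keep-last-N pruning discussed in the comment about `cfg.trainer.keep_last_checkpoints`. A stand-alone sketch of the same idea, assuming mtime ordering and a made-up helper name; the repo's version unlinks files and rmdirs itself rather than using `shutil`:

```python
import shutil
from pathlib import Path

def prune_checkpoints(ckpt_dir: Path, keep_last: int) -> None:
    # oldest first, so the slice below drops everything except the newest `keep_last`
    checkpoints = sorted(
        (d for d in ckpt_dir.iterdir() if d.is_dir()),
        key=lambda d: d.stat().st_mtime,
    )
    # guard keep_last == 0 explicitly, mirroring the "[:0] returns an empty list" pitfall noted above
    doomed = checkpoints[:-keep_last] if keep_last > 0 else []
    for d in doomed:
        shutil.rmtree(d)
```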
@@ -490,7 +490,7 @@ class Engines(dict[str, Engine]):
 res = feeder( engine=engine, batch=batch )
 break
 except RuntimeError as e:
-print("Forward", str(e))
+_logger.error("Forward", str(e))

 if "out of memory" not in str(e):
 self.save_checkpoint()
@@ -532,7 +532,7 @@ class Engines(dict[str, Engine]):
 try:
 engine.backward(loss)
 except RuntimeError as e:
-print("Backwards:", str(e))
+_logger.error("Backwards:", str(e))

 if "out of memory" not in str(e):
 self.save_checkpoint()
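Both hunks above sit inside the trainer's out-of-memory handling: a CUDA OOM is recognized by substring match on the `RuntimeError`, and anything else triggers a checkpoint before the error propagates. A simplified, self-contained sketch of that shape; the function and callback names are illustrative, not the repo's API:

```python
import torch

def guarded_step(run_step, save_checkpoint, retries=2):
    # retry the step after clearing the CUDA allocator cache on an OOM;
    # any other RuntimeError gets a checkpoint snapshot and is re-raised
    for attempt in range(retries):
        try:
            return run_step()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                save_checkpoint()
                raise
            torch.cuda.empty_cache()
    raise RuntimeError("step kept running out of memory")
```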
@@ -117,7 +117,7 @@ class Engine(DeepSpeedEngine):
 else:
 self.optimizer.set_lr(lr)
 except Exception as e:
-print(str(e))
+_logger.warning(str(e))

 # we'll just have to live with the LoRA weights living within our main weights
 # they're easy to extract anyways
@@ -2,6 +2,9 @@ import torch
 import torchaudio
 import soundfile
 import time
+import logging
+
+_logger = logging.getLogger(__name__)

 from torch import Tensor
 from einops import rearrange
@@ -31,14 +34,13 @@ class TTS():

 def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 if config:
-print("Loading YAML:", config)
+_logger.info(f"Loading YAML: {config}")
 cfg.load_yaml( config )

 try:
 cfg.format( training=False )
 cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
 except Exception as e:
-print("Error while parsing config YAML:")
 raise e # throw an error because I'm tired of silent errors messing things up for me

 if amp is None:
@@ -73,7 +75,7 @@ class TTS():

 self.engines.eval()
 self.symmap = get_phone_symmap()
-print("Loaded model")
+_logger.info("Loaded model")

 def encode_text( self, text, language="en" ):
 # already a tensor, return it
@@ -1,3 +1,6 @@
+import logging
+
+_logger = logging.getLogger(__name__)

 def get_model(config, training=True, **model_kwargs):
 name = config.name
@@ -53,7 +56,7 @@ def get_model(config, training=True, **model_kwargs):
 **model_kwargs
 )

-print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
+_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

 return model

@@ -15,6 +15,9 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+import logging
+
+_logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding

@@ -379,7 +382,7 @@ def example_usage():
 else:
 raise ValueError(f"Unrecognized optimizer: {optimizer}")

-print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

 optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -392,7 +395,7 @@ def example_usage():
 scheduler = None

 if scheduler is not None:
-print("Scheduler:", scheduler)
+_logger.info(f"Scheduler: {scheduler}")
 optimizer = scheduler( model.parameters(), lr = learning_rate )

 if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -425,7 +428,7 @@ def example_usage():
 }, f"./data/{cfg.model.arch_type}.pth" )
 """

-print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+_logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 @torch.no_grad()
 def sample_data(task=None):
@@ -17,6 +17,9 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+import logging
+
+_logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding

@@ -434,7 +437,7 @@ def example_usage():
 else:
 raise ValueError(f"Unrecognized optimizer: {optimizer}")

-print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

 optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -447,7 +450,7 @@ def example_usage():
 scheduler = None

 if scheduler is not None:
-print("Scheduler:", scheduler)
+_logger.info(f"Scheduler: {scheduler}")
 optimizer = scheduler( model.parameters(), lr = learning_rate )

 if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -480,7 +483,7 @@ def example_usage():
 }, f"./data/{cfg.model.arch_type}.pth" )
 """

-print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+_logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 @torch.no_grad()
 def sample_data(task=None):
@@ -2,6 +2,8 @@

 import math
 import torch
+import logging
+
 from typing import Literal, overload, Optional, Tuple

 from torch import Tensor, nn
@@ -10,6 +12,8 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

+_logger = logging.getLogger(__name__)
+
 AVAILABLE_ATTENTIONS = []

 try:
@@ -18,7 +22,7 @@ try:
 if is_flash_attn_2_available():
 AVAILABLE_ATTENTIONS.append("flash_attention_2")
 except Exception as e:
-print("Error while querying for `flash_attention_2` support", e)
+_logger.warning(f"Error while querying for `flash_attention_2` support: {str(e)}")

 try:
 from .attention.fused import attention as _fused_attention
@@ -27,7 +31,7 @@ try:

 AVAILABLE_ATTENTIONS.append("fused_attn")
 except Exception as e:
-print("Error while querying for `fused_attn` support", e)
+_logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")


 is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
@@ -99,7 +103,7 @@ try:
 has_flash_attn_with_paged = True
 except Exception as e:
 raise e
-print("Error while querying for `flash_attn` support", e)
+_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")

 try:
 from xformers.ops.fmha import memory_efficient_attention
@@ -107,7 +111,7 @@ try:

 AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
-print("Error while importing `xformers`", e)
+_logger.warning(f"Error while importing `xformers`: {str(e)}")

 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
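On the to-do above: the `torch.backends.cuda.*_sdp_enabled()` toggles only report whether a given SDPA kernel is enabled, not whether the current GPU can actually run it, which is why they "return true regardless". A tiny probe showing what those calls report; it is illustrative and does not exercise the kernels:

```python
import torch

available = []
if torch.backends.cuda.flash_sdp_enabled():
    available.append("flash")
if torch.backends.cuda.mem_efficient_sdp_enabled():
    available.append("mem_efficient")
if torch.backends.cuda.math_sdp_enabled():
    available.append("math")

# typically prints all three even on hardware without flash-attention support
print(available)
```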
@@ -20,6 +20,9 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult

 import random
 import math
+import logging
+
+_logger = logging.getLogger(__name__)

 from einops import rearrange
 from tqdm import trange
@@ -502,7 +505,7 @@ def example_usage():
 else:
 raise ValueError(f"Unrecognized optimizer: {optimizer}")

-print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

 optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -515,7 +518,7 @@ def example_usage():
 scheduler = None

 if scheduler is not None:
-print("Scheduler:", scheduler)
+_logger.info(f"Scheduler: {scheduler}")
 optimizer = scheduler( model.parameters(), lr = learning_rate )

 if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -532,7 +535,7 @@ def example_usage():
 }, f"./data/{cfg.model.arch_type}.pth" )
 """

-print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+_logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 @torch.inference_mode()
 def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
@@ -19,6 +19,9 @@ from torch import Tensor
 from tqdm import trange

 from ..emb.qnt import trim
+import logging
+
+_logger = logging.getLogger(__name__)

 class NAR(Base):
 def forward(
@@ -361,7 +364,7 @@ def example_usage():
 else:
 raise ValueError(f"Unrecognized optimizer: {optimizer}")

-print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

 optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -374,7 +377,7 @@ def example_usage():
 scheduler = None

 if scheduler is not None:
-print("Scheduler:", scheduler)
+_logger.info(f"Scheduler: {scheduler}")
 optimizer = scheduler( model.parameters(), lr = learning_rate )

 if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -391,7 +394,7 @@ def example_usage():
 }, f"./data/{cfg.model.arch_type}.pth" )
 """

-print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+_logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 @torch.inference_mode()
 def sample( name, steps=1000 ):
@@ -133,9 +133,6 @@ def run_eval(engines, eval_name, dl):
 }
 #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

-if cfg.trainer.no_logger:
-tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
-else:
 _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

@@ -160,8 +157,8 @@ def train():
 run_eval(engines, "subtrain", subtrain_dl)
 run_eval(engines, "val", val_dl)
 except Exception as e:
-print("Error occurred while performing eval:", str(e))
-print(traceback.format_exc())
+_logger.warning(f"Error occurred while performing eval: {str(e)}")
+_logger.warning(traceback.format_exc())

 engines.train()
 qnt.unload_model()
@@ -20,7 +20,6 @@ def get_free_port():

 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
-#print("Initializing distributed...")
 torch.cuda.set_device(local_rank())
 fn(*args, **kwargs)
 _distributed_initialized = True
@@ -29,8 +28,6 @@ def distributed_initialized():
 return _distributed_initialized

 def cleanup_distributed():
-#if not _distributed_initialized:
-# return
 dist.barrier()
 dist.destroy_process_group()

@@ -177,9 +177,6 @@ def train(
 except Exception as e:
 metrics = str(stats)

-if cfg.trainer.no_logger:
-tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
-else:
 _logger.info(f"Training Metrics: {truncate_json(metrics)}.")

 command = _non_blocking_input()
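This hunk and the matching one in `run_eval` drop the `cfg.trainer.no_logger` branch that fell back to `tqdm.write`, routing everything through the logger. If log records and a live tqdm bar end up fighting over the terminal, tqdm ships a helper for exactly that; a small sketch, assuming a reasonably recent tqdm that provides `tqdm.contrib.logging`:

```python
import logging
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

# inside the context manager, log records are emitted via tqdm.write,
# so they don't tear the progress bar apart
with logging_redirect_tqdm():
    for step in tqdm(range(3), desc="training"):
        _logger.info(f"Training Metrics: step {step}")
```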
@@ -220,9 +217,9 @@ def train(
 rate = float(command.split(" ")[-1])
 try:
 engines.set_lr(rate)
-print("Updating LR to:", rate)
+_logger.info(f"Updating LR to: {rate}")
 except Exception as e:
-print("Failed to set LR rate to:", rate, str(e))
+_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")

 if "export" in command:
 train_dl.dataset.save_state_dict()
@@ -14,6 +14,9 @@ import random
 import time
 import psutil
 import math
+import logging
+
+_logger = logging.getLogger(__name__)

 from coloredlogs import ColoredFormatter
 from logging import StreamHandler
@@ -296,7 +299,7 @@ def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
 )

 if verbose:
-print(f"Replacing {name}.{k} to", klass)
+_logger.info(f"Replacing {name}.{k} to: {klass}")

 return model

@@ -330,7 +333,7 @@ def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
 )

 if verbose:
-print(f"Replacing {name}.{k} to", klass)
+_logger.info(f"Replacing {name}.{k} to: {klass}")

 return model

@@ -360,7 +363,7 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
 )

 if verbose:
-print(f"Replacing {name}.{k} to", klass)
+_logger.info(f"Replacing {name}.{k} to: {klass}")

 return model

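The three `replace_*` helpers above share one idea: walk the module tree, swap matching submodules for a replacement class, and optionally log each swap. A self-contained sketch of that pattern under simplified assumptions; the helper name, the `factory` callback, and the weight-copy details are illustrative, not the repo's exact signatures:

```python
from torch import nn

def replace_modules(model: nn.Module, target=nn.Linear, factory=None, verbose=False):
    # recurse first so nested containers are handled, then swap direct children
    for name, child in model.named_children():
        replace_modules(child, target, factory, verbose)
        if isinstance(child, target) and factory is not None:
            new = factory(child.in_features, child.out_features, child.bias is not None)
            new.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                new.bias.data.copy_(child.bias.data)
            setattr(model, name, new)
            if verbose:
                print(f"Replacing {name}: {type(child).__name__} -> {type(new).__name__}")
    return model

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
replace_modules(model, factory=lambda i, o, bias: nn.Linear(i, o, bias=bias), verbose=True)
```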
@@ -491,7 +494,7 @@ def get_model_offload_policy(module, policy=None):
 # does not fit in budget, increase device index
 else:
 device_index += 1
-print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
+_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")

 # to-do: check that all modules are exhausted
 assert module_index >= len(modules)
@@ -528,9 +531,9 @@ def offload_model( model, policy=None ):
 if not not [*module.named_children()]:
 continue
 try:
-print( name, next(module.parameters()).device )
+_logger.info( name, next(module.parameters()).device )
 except Exception as e:
-print( name, "?" )
+_logger.info( name, "?" )
 pass
 """

@@ -3,9 +3,12 @@ from contextlib import contextmanager
 import math
 import torch
 import torch.nn.functional as F
+import logging

 from ..config import cfg

+_logger = logging.getLogger(__name__)
+
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

@@ -95,7 +98,7 @@ if cfg.optimizations.tensorrt:
 import torch_tensorrt
 AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
 except Exception as e:
-print('Error while importing TensorRT:', str(e))
+_logger.warning(f'Error while importing TensorRT: {str(e)}')
 pass

 def compile_model(model, backend="auto"):
@@ -111,14 +114,14 @@ def compile_model(model, backend="auto"):
 try:
 from prodigyopt import Prodigy
 except Exception as e:
-print('Error while importing Prodigyopt:', str(e))
+_logger.warning(f'Error while importing Prodigyopt: {str(e)}')
 pass

 # https://github.com/facebookresearch/schedule_free/
 try:
 import schedulefree
 except Exception as e:
-print('Error while importing Schedule_Free:', str(e))
+_logger.warning(f'Error while importing Schedule_Free: {str(e)}')
 pass

 # backwards compat
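The imports above all follow the same graceful-degradation pattern: try the optional dependency, and on failure log a warning and carry on with the feature disabled. A generic sketch of that pattern, assuming a hypothetical availability flag rather than the repo's actual config switches:

```python
import logging

_logger = logging.getLogger(__name__)

try:
    import schedulefree  # optional dependency
    schedulefree_available = True
except Exception as e:
    schedulefree_available = False
    _logger.warning(f"Error while importing Schedule_Free: {str(e)}")

# later code can branch on the flag instead of importing again
if schedulefree_available:
    _logger.info("schedule-free optimizers are available")
```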