Added knowledge distillation to the trainer (sadly it is not model-agnostic, because of the grave mistake of further processing the batch within the forward pass, so subsequent calls do not match).
parent 4e21df8092 · commit 23d402bf01
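For orientation, the distillation objective wired in below mixes the usual hard cross-entropy loss with a soft KL-divergence term against the teacher's logits, weighted by the new `teacher_alpha` setting. A minimal sketch of that mixing, with illustrative names (`distill_loss`, `student_logits`, `teacher_logits` are not part of the repo):

import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, alpha=0.5, ignore_index=-100):
    # hard loss: cross-entropy against the ground-truth tokens
    hard = F.cross_entropy(student_logits, targets, ignore_index=ignore_index)
    # soft loss: KL divergence between the student and teacher distributions
    soft = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )
    # blend the two terms with the distillation weight
    return (1.0 - alpha) * hard + alpha * soft

# toy usage: 8 positions over a 32-token vocabulary
student = torch.randn(8, 32, requires_grad=True)
teacher = torch.randn(8, 32)
targets = torch.randint(0, 32, (8,))
distill_loss(student, teacher, targets).backward()

The actual diff applies the scaling per entry of a loss dict rather than to a single scalar, but the weighting is the same idea.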
@@ -115,7 +115,8 @@ class BaseConfig:
             raise Exception(f'Model path does not exist: {model_path}')

         # load state dict and copy its stored model config
-        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
+        model_kwargs = { "attention": "auto", "training": False, "teacher": False }
+        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
         lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []

         state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
@@ -279,6 +280,8 @@ class ModelExperimentalSettings:
     layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
     layerskip_e_scale: float = 0.2 # early-exit loss scalar value

+    teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+
 # I really need to clean this up
 @dataclass()
 class Model:
@@ -291,7 +294,9 @@ class Model:
     tones: int = 1 # defined tones (unsued)
     experts: int = 1 # for mixtral / retnet-ts
     arch_type: str = "llama" # underling LM architecture used
-    training: bool = True # I really need to attend to this
+    training: bool = False # I really need to attend to this
+    teacher: bool = False # if this is to be treated as a teacher
+
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     attention: str = "auto" # for llama arch_types: attention used
     dropout: float = 0.1 # adjustable dropout value
@@ -1006,6 +1011,11 @@ class Config(BaseConfig):
             if isinstance( model.experimental, dict ):
                 model.experimental = ModelExperimentalSettings(**model.experimental)

+            if model.teacher:
+                model.training = False
+            if model.training:
+                model.teacher = False
+
         if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
             self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
             self.hyperparameters.scheduler_type = ""
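The intent of the hunk above is that `teacher` and `training` are mutually exclusive on a model config: a teacher is always run in eval mode. A hedged sketch of that resolution (the `ModelFlags` class is illustrative; the field names follow the diff):

from dataclasses import dataclass

@dataclass
class ModelFlags:
    training: bool = False
    teacher: bool = False

    def resolve(self):
        # a teacher is never trained; a model being trained is never a teacher
        if self.teacher:
            self.training = False
        if self.training:
            self.teacher = False
        return self

print(ModelFlags(training=True, teacher=True).resolve())  # teacher wins: training becomes False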
@@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
         stats = None
         lora = None

-        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+        inferencing = cfg.mode == "inferencing" or not model.config.training or not training or model.config.teacher
         backend = cfg.inference.backend if inferencing else cfg.trainer.backend
         loads_state_dict = cfg.trainer.load_state_dict # or inferencing
@@ -327,6 +327,12 @@ def load_engines(training=True, **model_kwargs):
         if cfg.optimizations.model_offloading:
             engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )

+        # set to train/eval
+        if engine.hyper_config.training:
+            engine.module.train()
+        else:
+            engine.module.eval()
+
         # setup wandb
         if engine._training and cfg.trainer.wandb and wandb is not None:
             key_name = name
@@ -104,6 +104,12 @@ class Engine():
             return True
         return self.hyper_config.training

+    @property
+    def _teacher(self):
+        if not hasattr(self, "hyper_config"):
+            return False
+        return self.hyper_config.teacher
+
     @property
     def global_step(self):
         return self.global_steps
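The `_teacher` property mirrors `_training`: it guards against the engine being queried before `hyper_config` is attached and otherwise just reflects the config flag. A small sketch of that defensive-property pattern (class and attribute names are illustrative):

class EngineSketch:
    @property
    def _teacher(self):
        # engines queried before hyper_config is attached count as "not a teacher"
        if not hasattr(self, "hyper_config"):
            return False
        return self.hyper_config.teacher

class HyperConfig:
    teacher = True

engine = EngineSketch()
print(engine._teacher)              # False: no hyper_config yet
engine.hyper_config = HyperConfig()
print(engine._teacher)              # True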
@@ -352,6 +358,9 @@ class Engines(dict[str, Engine]):
                 lora, module = lora_get_state_dict( module, split = True )
                 save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"

+            config_dict = dict(**config.__dict__)
+            config_dict |= {"experimental": config.experimental.__dict__}
+
             state_dict = {
                 'module': module,
                 'lora': lora,
@@ -362,7 +371,7 @@ class Engines(dict[str, Engine]):
                     "tokens_processed": engine.tokens_processed,
                 },
                 "userdata": userdata,
-                "config": config.__dict__ | {"experimental": config.experimental.__dict__} # i hate implicit aliasing rules
+                "config": config_dict
             }

             if lora is None:
@@ -478,8 +487,17 @@ class Engines(dict[str, Engine]):
         if cfg.trainer.gc_mode == 'step':
             do_gc()

+        # preiterate to get teacher
+        teacher = None
+        for name, engine in self.items():
+            if not engine._teacher:
+                continue
+            teacher = engine.module
+            break
+
         for name, engine in self.items():
             # only models that we're training
-            if not engine._training:
+            if not engine._training or engine._teacher:
                 continue

             device = engine.device
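The pre-iteration above just grabs the first engine flagged as a teacher and hands its module to every student's feeder call, while the training loop itself skips the teacher. A simplified sketch of that selection (the `engines` dict, `_teacher`/`_training` flags, and `feeder` callable mirror the diff but the function itself is illustrative):

def select_teacher_and_train(engines, batch, feeder):
    # pick the first engine flagged as a teacher, if any
    teacher = None
    for name, engine in engines.items():
        if engine._teacher:
            teacher = engine.module
            break

    results = {}
    for name, engine in engines.items():
        # only feed models that are actually being trained (the teacher is skipped)
        if not engine._training or engine._teacher:
            continue
        results[name] = feeder(engine=engine, batch=batch, teacher=teacher)
    return results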
@@ -493,10 +511,10 @@ class Engines(dict[str, Engine]):
             n_ooms = torch.zeros([], device=device)

             if not cfg.trainer.check_for_oom:
-                res = feeder( engine=engine, batch=batch )
+                res = feeder( engine=engine, batch=batch, teacher=teacher )
             else:
                 try:
-                    res = feeder( engine=engine, batch=batch )
+                    res = feeder( engine=engine, batch=batch, teacher=teacher )
                 except RuntimeError as e:
                     _logger.error(f"Forward: {str(e)}")
@@ -92,6 +92,10 @@ class Engine(DeepSpeedEngine):
     def _training(self):
         return self.hyper_config.training

+    @property
+    def _teacher(self):
+        return self.hyper_config.teacher
+
     @property
     def global_step(self):
         return self.global_steps
@@ -32,6 +32,10 @@ from ..samplers import cfg_logits
 text_task = [ "stt" ]

 class AR_NAR(Base):
+    # yikes
+    def forward_super(self, *args, **kwargs):
+        return super().forward(*args, **kwargs)
+
     # parse inputs for training
     # a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
     def forward_train(
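`forward_super` exists so the teacher can be driven through `Base.forward` directly, skipping `AR_NAR`'s own training-time batch preprocessing (which is what makes the distillation path non-agnostic, per the commit message). A generic sketch of that delegation pattern, with illustrative class names:

class Base:
    def forward(self, *args, **kwargs):
        # stand-in for the real forward that returns raw logits
        return {"args": args, "kwargs": kwargs}

class ARNAR(Base):
    def forward_super(self, *args, **kwargs):
        # jump straight to the parent's forward, skipping this class's preprocessing
        return super().forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # the normal path massages the batch before handing it to Base.forward
        return super().forward(*args, preprocessed=True, **kwargs)

teacher = ARNAR()
print(teacher.forward_super(inputs=[1, 2, 3]))  # no extra preprocessing applied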
@@ -44,6 +48,8 @@ class AR_NAR(Base):
         lang_list: list[Tensor] | None = None,
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,
+
+        teacher = None,
     ):
         # deduce batch_size
         if text_list is not None:
@@ -198,6 +204,7 @@ class AR_NAR(Base):
         return super().forward(
             inputs=inputs,
             quant_levels=quant_levels,
+            teacher=teacher,
         )

     def forward_nar_masked(
@@ -834,7 +841,8 @@ class AR_NAR(Base):
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,

-        training: bool | int | None = None,
+        training: bool | None = None,
+        teacher = None,

         disable_tqdm=False,
         use_lora=None,
@@ -871,8 +879,8 @@ class AR_NAR(Base):
             lang_list=lang_list,
             tone_list=tone_list,
             len_list=len_list,
             disable_tqdm=disable_tqdm,
             use_lora=use_lora,
+
+            teacher=teacher,
         )

         # is NAR
@@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap

 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
+Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
@@ -389,12 +389,13 @@ class Base(nn.Module):

         l_padding: int = 0,

         training = True,
         attention = None,
         config = None,
     ):
         super().__init__()
         self.training = training
+        self.teaching = False
         self.config = config

         self.n_text_tokens = n_text_tokens
@@ -428,6 +429,11 @@ class Base(nn.Module):
         if not attention:
             attention = self.config.attention if self.config is not None else "auto"

+        # crunge
+        if self.config is not None and config.teacher:
+            self.teaching = True
+            self.training = False
+
         attention_backend = attention
         audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
         split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
@@ -436,6 +442,7 @@ class Base(nn.Module):
         unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
         interleave = self.config.experimental.interleave if self.config is not None else False
         noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
+        teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5

         masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
         ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -485,6 +492,7 @@ class Base(nn.Module):
         self.masking_ratio = masking_ratio
         self.ignore_inputs_for_loss = ignore_inputs_for_loss
         self.noncausal_masks = noncausal_masks
+        self.teacher_alpha = teacher_alpha

         # use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
         """
@@ -1265,6 +1273,8 @@ class Base(nn.Module):
         logits,

         quant_levels: list[int] | None = None,
+        compute_hard_loss = True,
+        compute_acc = True,
     ):
         loss = {}
         stats = {}
@@ -1297,6 +1307,7 @@ class Base(nn.Module):
         task_type = "tts"
         dropout_mask = None
         classifier_level = None
+        output_len = 0

         for name, input in batch:
             if name == "task":
@@ -1354,8 +1365,11 @@ class Base(nn.Module):
             it += seq_len + 1 # +1 to incorporate the separator

             # deduce if a name for a task is an input or output
-            if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
-                ignored = True
+            if name != task_outputs.get(task_type, name):
+                if self.ignore_inputs_for_loss:
+                    ignored = True
+            else:
+                output_len = seq_len

             if ignored:
                 # pruned
@@ -1378,20 +1392,20 @@ class Base(nn.Module):
                 logit = logit[..., :-l, :]
                 token = token[..., l:] # shift sequence to the right by one (or causal chunk size)

-                if f'{name}.nll' not in loss:
-                    loss[f'{name}.nll'] = []
+                if compute_hard_loss:
+                    nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
+                    if f'{name}.nll' not in loss:
+                        loss[f'{name}.nll'] = []
+                    loss[f'{name}.nll'].append( nll )

-                if f'{name}.acc' not in stats:
-                    stats[f'{name}.acc'] = []
-
-                nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
-                if self.metrics is not None:
-                    metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
-                else:
-                    metrics = self.accuracy_metric( logit, token )
-
-                loss[f'{name}.nll'].append( nll )
-                stats[f'{name}.acc'].append( metrics )
+                if compute_acc:
+                    if self.metrics is not None:
+                        metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
+                    else:
+                        metrics = self.accuracy_metric( logit, token )
+                    if f'{name}.acc' not in stats:
+                        stats[f'{name}.acc'] = []
+                    stats[f'{name}.acc'].append( metrics )
                 # add to list
             else:
                 target.append( token )
@@ -1407,21 +1421,21 @@ class Base(nn.Module):
             logit = logit[..., :-l, :] # shift the target so that token n...
             target = target[..., l:] # ...predicts token n + 1

-            nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
+            if compute_hard_loss:
+                nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
+                if 'nll' not in loss:
+                    loss['nll'] = []
+                loss["nll"].append( nll )

-            if self.metrics is not None:
-                metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
-            else:
-                metrics = self.accuracy_metric( logit, target )
-
-            if 'nll' not in loss:
-                loss['nll'] = []
-
-            if 'acc' not in stats:
-                stats['acc'] = []
-
-            loss["nll"].append( nll )
-            stats["acc"].append( metrics )
+            if compute_acc:
+                if self.metrics is not None:
+                    metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
+                else:
+                    metrics = self.accuracy_metric( logit, target )
+                if 'acc' not in stats:
+                    stats['acc'] = []
+                stats["acc"].append( metrics )

         # average
         loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
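The two hunks above reshuffle `calc_loss` so the hard cross-entropy term and the accuracy metric are each optional, which lets a caller skip either when only the distillation term is wanted. A condensed, hedged sketch of that gating (the `gated_loss` helper, tensor shapes, and the `loss`/`stats` dicts are illustrative):

import torch
import torch.nn.functional as F

def gated_loss(logit, target, compute_hard_loss=True, compute_acc=True, ignore_index=-100):
    loss, stats = {}, {}
    if compute_hard_loss:
        nll = F.cross_entropy(logit, target, ignore_index=ignore_index)
        loss.setdefault("nll", []).append(nll)
    if compute_acc:
        acc = (logit.argmax(dim=-1) == target).float().mean()
        stats.setdefault("acc", []).append(acc)
    return loss, stats

logit = torch.randn(8, 32)
target = torch.randint(0, 32, (8,))
print(gated_loss(logit, target, compute_hard_loss=False))  # accuracy only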
@@ -1440,6 +1454,8 @@ class Base(nn.Module):

         output_attentions: bool = False,
         output_hidden_states: bool = False,
+
+        teacher = None,
     ):
         # return early if it's "good" enough"
         # lambda because we need to capture the classifier_levels and mask
@@ -1492,6 +1508,7 @@ class Base(nn.Module):
         x, mask = list_to_tensor(x_list)

         training = self.training
+        teaching = self.teaching
         device = x.device
         batch_size = len(x_list)
@@ -1566,8 +1583,14 @@ class Base(nn.Module):
             hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]

         # compute loss if the target is given
-        if training:
-            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+        if not training:
+            loss = None
+            stats = None
+
+            self.loss = None
+            self.stats = None
+        else:
+            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )

             # compute it as an aux-loss
             if self.layerskip:
@@ -1590,15 +1613,41 @@ class Base(nn.Module):
             # to-do: instead make the cirriculum rely on samples processed instead of steps
             self.training_steps += 1 # batch_size

+            # get soft targets from teacher
+            # it might be better to compute these once instead of per-engine, but realistically who is actually training multiple models
+            if teacher is not None:
+                with torch.no_grad():
+                    teacher_output = teacher.forward_super(
+                        inputs=inputs,
+                        quant_levels=quant_levels,
+                    )
+
+                soft_loss = [
+                    F.kl_div(
+                        F.log_softmax( student, dim=-1 ).unsqueeze(0),
+                        F.softmax( teacher, dim=-1 ).unsqueeze(0),
+                        reduction='batchmean'
+                    )
+                    for student, teacher in zip( logits, teacher_output.logits )
+                ]
+                soft_loss = torch.stack([*soft_loss]).sum() / batch_size
+
+                # mix if not nan
+                if not torch.isnan(soft_loss).any():
+                    alpha = self.teacher_alpha
+                    loss['kl'] = alpha * soft_loss
+                    for k in loss.keys():
+                        loss[k] *= (1.0 - alpha)
+
             # include any additional losses (for example: MoE router)
-            if output.aux_loss is not None:
-                loss["aux_loss"] = output.aux_loss
+            if output.loss is not None:
+                loss["aux_loss"] = output.loss

             self.loss = loss
             self.stats = stats

         # rewrap, because we're modifying the logits here
-        return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
+        return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)

     def sample(
         self,
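To make the hunk above easier to follow in isolation, here is a hedged, standalone restatement of the soft-target step: the teacher runs under `no_grad`, each student/teacher logit pair contributes a `batchmean` KL term, and if the result is finite it is mixed into the loss dict via `teacher_alpha`. Variable names are illustrative, and the per-sequence random tensors merely stand in for the model's variable-length outputs.

import torch
import torch.nn.functional as F

def mix_in_soft_targets(loss, student_logits, teacher_logits, alpha=0.5):
    # one KL term per sequence, averaged over the batch afterwards
    soft_terms = [
        F.kl_div(
            F.log_softmax(student, dim=-1).unsqueeze(0),
            F.softmax(teacher, dim=-1).unsqueeze(0),
            reduction="batchmean",
        )
        for student, teacher in zip(student_logits, teacher_logits)
    ]
    soft_loss = torch.stack(soft_terms).sum() / len(student_logits)

    # only mix when the soft loss is finite; otherwise keep the hard losses as-is
    if not torch.isnan(soft_loss).any():
        for k in loss:
            loss[k] = loss[k] * (1.0 - alpha)
        loss["kl"] = alpha * soft_loss
    return loss

# toy usage: two "sequences" of different lengths over a 32-token vocabulary
student = [torch.randn(5, 32, requires_grad=True), torch.randn(7, 32, requires_grad=True)]
with torch.no_grad():
    teacher = [torch.randn(5, 32), torch.randn(7, 32)]
loss = {"nll": F.cross_entropy(student[0], torch.randint(0, 32, (5,)))}
loss = mix_in_soft_targets(loss, student, teacher)
print({k: float(v) for k, v in loss.items()})

Note that this sketch scales the existing entries before inserting the 'kl' entry, so the KL term keeps its alpha weight; treat it as one reasonable ordering rather than a faithful copy of the hunk above.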
@@ -26,7 +26,7 @@ _logger = logging.getLogger(__name__)

 mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")

-def train_feeder(engine, batch):
+def train_feeder(engine, batch, teacher=None):
     with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
         batch_size = len(batch["text"])
         engine.current_batch_size = batch_size
|
@ -40,6 +40,7 @@ def train_feeder(engine, batch):
|
|||
task_list=batch["task"],
|
||||
|
||||
training=True,
|
||||
teacher=teacher,
|
||||
)
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
|
|
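Finally, the feeder change is just a pass-through: the trainer hands the selected teacher module to `train_feeder`, which forwards it into the model call. A stripped-down sketch of that plumbing (the engine call signature here is illustrative, not the repo's exact API):

def train_feeder_sketch(engine, batch, teacher=None):
    # the feeder only threads the teacher module through to the model call
    engine.current_batch_size = len(batch["text"])
    return engine(
        text_list=batch["text"],
        training=True,
        teacher=teacher,  # None means no distillation term is added
    )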