added knowledge distillation in the trainer (sadly it is not agnostic because of the grave mistake of further processing the batch within the forward pass, so subsequent calls do not match......)

This commit is contained in:
mrq 2024-12-05 23:05:52 -06:00
parent 4e21df8092
commit 23d402bf01
7 changed files with 142 additions and 46 deletions
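The gist of this commit, before the per-file hunks: the teacher's logits are computed under torch.no_grad(), compared against the student's logits with a KL-divergence term, and blended with the usual hard loss by teacher_alpha. A minimal, self-contained sketch of that idea follows; the function name, the per-sample list layout, and the argument shapes are illustrative assumptions, not the repo's API.

import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: list[torch.Tensor],  # one [seq_len, vocab] tensor per sample
    teacher_logits: list[torch.Tensor],  # matching shapes, produced under torch.no_grad()
    hard_loss: torch.Tensor,             # the usual cross-entropy against the ground-truth targets
    alpha: float = 0.5,                  # mixing factor, analogous to teacher_alpha
) -> torch.Tensor:
    # per-sample KL term; F.kl_div treats its second argument as the target
    # distribution, so this measures KL(teacher || student)
    soft_loss = torch.stack([
        F.kl_div(
            F.log_softmax(s, dim=-1),
            F.softmax(t, dim=-1),
            reduction="batchmean",
        )
        for s, t in zip(student_logits, teacher_logits)
    ]).sum() / len(student_logits)

    # fall back to the hard loss alone if the soft term is degenerate
    if torch.isnan(soft_loss).any():
        return hard_loss
    return (1.0 - alpha) * hard_loss + alpha * soft_loss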

View File

@ -115,7 +115,8 @@ class BaseConfig:
raise Exception(f'Model path does not exist: {model_path}')
# load state dict and copy its stored model config
model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
model_kwargs = { "attention": "auto", "training": False, "teacher": False }
model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
@ -279,6 +280,8 @@ class ModelExperimentalSettings:
layerskip_p_max: float = 0.1 # maximum probability to drop out the last layer, used for calculating layer dropout probabilities
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
# I really need to clean this up
@dataclass()
class Model:
@ -291,7 +294,9 @@ class Model:
tones: int = 1 # defined tones (unused)
experts: int = 1 # for mixtral / retnet-ts
arch_type: str = "llama" # underlying LM architecture used
training: bool = True # I really need to attend to this
training: bool = False # I really need to attend to this
teacher: bool = False # if this is to be treated as a teacher
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "auto" # for llama arch_types: attention used
dropout: float = 0.1 # adjustable dropout value
@ -1006,6 +1011,11 @@ class Config(BaseConfig):
if isinstance( model.experimental, dict ):
model.experimental = ModelExperimentalSettings(**model.experimental)
if model.teacher:
model.training = False
if model.training:
model.teacher = False
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
self.hyperparameters.scheduler_type = ""

View File

@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
stats = None
lora = None
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
inferencing = cfg.mode == "inferencing" or not model.config.training or not training or model.config.teacher
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
loads_state_dict = cfg.trainer.load_state_dict # or inferencing
@ -327,6 +327,12 @@ def load_engines(training=True, **model_kwargs):
if cfg.optimizations.model_offloading:
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
# set to train/eval
if engine.hyper_config.training:
engine.module.train()
else:
engine.module.eval()
# setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None:
key_name = name

View File

@ -104,6 +104,12 @@ class Engine():
return True
return self.hyper_config.training
@property
def _teacher(self):
if not hasattr(self, "hyper_config"):
return False
return self.hyper_config.teacher
@property
def global_step(self):
return self.global_steps
@ -352,6 +358,9 @@ class Engines(dict[str, Engine]):
lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
config_dict = dict(**config.__dict__)
config_dict |= {"experimental": config.experimental.__dict__}
state_dict = {
'module': module,
'lora': lora,
@ -362,7 +371,7 @@ class Engines(dict[str, Engine]):
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata,
"config": config.__dict__ | {"experimental": config.experimental.__dict__} # i hate implicit aliasing rules
"config": config_dict
}
if lora is None:
@ -478,8 +487,17 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step':
do_gc()
# pre-iterate to get the teacher
teacher = None
for name, engine in self.items():
if not engine._training:
if not engine._teacher:
continue
teacher = engine.module
break
for name, engine in self.items():
# only models that we're training
if not engine._training or engine._teacher:
continue
device = engine.device
@ -493,10 +511,10 @@ class Engines(dict[str, Engine]):
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch )
res = feeder( engine=engine, batch=batch, teacher=teacher )
else:
try:
res = feeder( engine=engine, batch=batch )
res = feeder( engine=engine, batch=batch, teacher=teacher )
except RuntimeError as e:
_logger.error(f"Forward: {str(e)}")
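Pulling the two loops in the hunks above together: the step first finds the engine flagged as a teacher, then hands its module to the feeder of every engine that is actually being trained. A consolidated sketch, with run_step, engines, batch, and feeder as hypothetical stand-ins for the surrounding training-step code:

def run_step(engines, batch, feeder):
    # pre-iterate to find the teacher: the first engine flagged as such wins
    teacher = None
    for name, engine in engines.items():
        if engine._teacher:
            teacher = engine.module
            break

    results = {}
    for name, engine in engines.items():
        # teachers and non-training engines are skipped for the forward/backward pass
        if not engine._training or engine._teacher:
            continue
        results[name] = feeder(engine=engine, batch=batch, teacher=teacher)
    return results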

View File

@ -92,6 +92,10 @@ class Engine(DeepSpeedEngine):
def _training(self):
return self.hyper_config.training
@property
def _teacher(self):
return self.hyper_config.teacher
@property
def global_step(self):
return self.global_steps

View File

@ -32,6 +32,10 @@ from ..samplers import cfg_logits
text_task = [ "stt" ]
class AR_NAR(Base):
# yikes
def forward_super(self, *args, **kwargs):
return super().forward(*args, **kwargs)
# parse inputs for training
# a lot of this could be delegated back to the dataloader, but it's easier to keep the dataloader's job to providing sufficient data and let the model process that data for training
def forward_train(
@ -44,6 +48,8 @@ class AR_NAR(Base):
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
teacher = None,
):
# deduce batch_size
if text_list is not None:
@ -198,6 +204,7 @@ class AR_NAR(Base):
return super().forward(
inputs=inputs,
quant_levels=quant_levels,
teacher=teacher,
)
def forward_nar_masked(
@ -834,7 +841,8 @@ class AR_NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
training: bool | None = None,
teacher = None,
disable_tqdm=False,
use_lora=None,
@ -871,8 +879,8 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
teacher=teacher,
)
# is NAR

View File

@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
@ -389,12 +389,13 @@ class Base(nn.Module):
l_padding: int = 0,
training = True,
attention = None,
config = None,
):
super().__init__()
self.training = training
self.teaching = False
self.config = config
self.n_text_tokens = n_text_tokens
@ -428,6 +429,11 @@ class Base(nn.Module):
if not attention:
attention = self.config.attention if self.config is not None else "auto"
# crunge
if self.config is not None and config.teacher:
self.teaching = True
self.training = False
attention_backend = attention
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
@ -436,6 +442,7 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@ -485,6 +492,7 @@ class Base(nn.Module):
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.teacher_alpha = teacher_alpha
# use internal attention mechanism for now because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
"""
@ -1265,6 +1273,8 @@ class Base(nn.Module):
logits,
quant_levels: list[int] | None = None,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
@ -1297,6 +1307,7 @@ class Base(nn.Module):
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
@ -1354,8 +1365,11 @@ class Base(nn.Module):
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
ignored = True
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
else:
output_len = seq_len
if ignored:
# pruned
@ -1378,20 +1392,20 @@ class Base(nn.Module):
logit = logit[..., :-l, :]
token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
if compute_hard_loss:
nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
loss[f'{name}.nll'].append( nll )
if f'{name}.acc' not in stats:
stats[f'{name}.acc'] = []
nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, token )
loss[f'{name}.nll'].append( nll )
stats[f'{name}.acc'].append( metrics )
if compute_acc:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, token )
if f'{name}.acc' not in stats:
stats[f'{name}.acc'] = []
stats[f'{name}.acc'].append( metrics )
# add to list
else:
target.append( token )
@ -1407,21 +1421,21 @@ class Base(nn.Module):
logit = logit[..., :-l, :] # shift the target so that token n...
target = target[..., l:] # ...predicts token n + 1
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
if compute_hard_loss:
nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, target )
if compute_acc:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
else:
metrics = self.accuracy_metric( logit, target )
if 'nll' not in loss:
loss['nll'] = []
if 'acc' not in stats:
stats['acc'] = []
loss["nll"].append( nll )
stats["acc"].append( metrics )
if 'acc' not in stats:
stats['acc'] = []
stats["acc"].append( metrics )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
@ -1440,6 +1454,8 @@ class Base(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
teacher = None,
):
# return early if it's "good enough"
# lambda because we need to capture the classifier_levels and mask
@ -1492,6 +1508,7 @@ class Base(nn.Module):
x, mask = list_to_tensor(x_list)
training = self.training
teaching = self.teaching
device = x.device
batch_size = len(x_list)
@ -1566,8 +1583,14 @@ class Base(nn.Module):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# compute loss if the target is given
if training:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
if not training:
loss = None
stats = None
self.loss = None
self.stats = None
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )
# compute it as an aux-loss
if self.layerskip:
@ -1590,15 +1613,41 @@ class Base(nn.Module):
# to-do: make the curriculum rely on samples processed instead of steps
self.training_steps += 1 # batch_size
# get soft targets from teacher
# it might be better to compute these once instead of per-engine, but realistically, who is actually training multiple models?
if teacher is not None:
with torch.no_grad():
teacher_output = teacher.forward_super(
inputs=inputs,
quant_levels=quant_levels,
)
soft_loss = [
F.kl_div(
F.log_softmax( student, dim=-1 ).unsqueeze(0),
F.softmax( teacher, dim=-1 ).unsqueeze(0),
reduction='batchmean'
)
for student, teacher in zip( logits, teacher_output.logits )
]
soft_loss = torch.stack([*soft_loss]).sum() / batch_size
# mix if not nan
if not torch.isnan(soft_loss).any():
alpha = self.teacher_alpha
for k in loss.keys():
loss[k] *= (1.0 - alpha) # scale the hard losses first so the soft term keeps its full alpha weight
loss['kl'] = alpha * soft_loss
# include any additional losses (for example: MoE router)
if output.aux_loss is not None:
loss["aux_loss"] = output.aux_loss
if output.loss is not None:
loss["aux_loss"] = output.loss
self.loss = loss
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
def sample(
self,

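In the hunk above, the distillation term is folded into the dict of named loss terms rather than a single scalar. A small sketch of that mixing step, under the same assumptions as before (alpha standing in for teacher_alpha), with the hard terms scaled before the KL entry is added so the soft term keeps its full weight:

import torch

def mix_soft_targets(
    loss: dict[str, torch.Tensor],  # named hard-loss terms, e.g. {"nll": ...}
    soft_loss: torch.Tensor,        # batch-averaged KL term from the teacher
    alpha: float = 0.5,             # mixing factor, analogous to teacher_alpha
) -> dict[str, torch.Tensor]:
    # leave the hard losses untouched if the distillation term is NaN
    if torch.isnan(soft_loss).any():
        return loss
    mixed = {k: v * (1.0 - alpha) for k, v in loss.items()}  # down-weight hard terms
    mixed["kl"] = alpha * soft_loss                          # add the soft-target term
    return mixed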
View File

@ -26,7 +26,7 @@ _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch):
def train_feeder(engine, batch, teacher=None):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
@ -40,6 +40,7 @@ def train_feeder(engine, batch):
task_list=batch["task"],
training=True,
teacher=teacher,
)
losses = engine.gather_attribute("loss")