Optimized code for Ignoring last CLIP layers

This commit is contained in:
Fampai 2022-10-08 16:32:05 -04:00 committed by AUTOMATIC1111
parent 6c383d2e82
commit e59c66c008

View File

@ -282,11 +282,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
tmp = -opts.CLIP_ignore_last_layers tmp = -opts.CLIP_stop_at_last_layers
if (opts.CLIP_ignore_last_layers == 0):
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
z = outputs.last_hidden_state
else:
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
z = outputs.hidden_states[tmp] z = outputs.hidden_states[tmp]
z = self.wrapped.transformer.text_model.final_layer_norm(z) z = self.wrapped.transformer.text_model.final_layer_norm(z)