Removed unnecessary tmp variable
This commit is contained in:
parent
ad3ae44108
commit
1824e9ee3a
|
@ -282,10 +282,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
|
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
|
||||||
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
|
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
|
||||||
|
|
||||||
tmp = -opts.CLIP_stop_at_last_layers
|
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
|
if opts.CLIP_stop_at_last_layers > 1:
|
||||||
if tmp < -1:
|
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||||
z = outputs.hidden_states[tmp]
|
|
||||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||||
else:
|
else:
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
|
Loading…
Reference in New Issue
Block a user