Skip to content

vllm.model_executor.models.transformers.causal

Transformers backend mixin for causal language models.

CausalMixin

Bases: VllmModelForTextGeneration

Source code in vllm/model_executor/models/transformers/causal.py
class CausalMixin(VllmModelForTextGeneration):
    """Give a Transformers-backend model a causal language-modeling head.

    On the last pipeline-parallel rank this mixin builds ``lm_head`` (a
    ``ParallelLMHead``, tied to the input embeddings when the text config sets
    ``tie_word_embeddings``) plus a ``LogitsProcessor``. Every other rank only
    gets a ``PPMissingLayer`` placeholder for ``lm_head`` and no logits
    processor.
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        # Begin the MRO lookup *after* VllmModelForTextGeneration, so its
        # __init__ is skipped and the next class in the MRO does the setup.
        super(VllmModelForTextGeneration, self).__init__(
            vllm_config=vllm_config, prefix=prefix
        )
        text_cfg = self.text_config
        tied = text_cfg.tie_word_embeddings

        if tied:
            # lm_head will be tied to the input embeddings below, so tell
            # `Base.load_weights` to skip weights under the `lm_head.` prefix.
            self.skip_prefixes.append("lm_head.")

        if not self.pp_group.is_last_rank:
            # Non-last pipeline ranks only hold a placeholder layer.
            self.lm_head = PPMissingLayer()
            return

        vocab = text_cfg.vocab_size
        self.unpadded_vocab_size = vocab
        self.lm_head = ParallelLMHead(
            vocab,
            text_cfg.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied:
            # Replace the head with one tied to the model's input embeddings.
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings()
            )

        # `logit_scale` is optional in the text config; 1.0 when absent.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            vocab,
            getattr(text_cfg, "logit_scale", 1.0),
        )

    def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
        """Project ``hidden_states`` to vocabulary logits.

        Delegates to ``self.logits_processor`` with ``self.lm_head``. Both
        are only created on the last pipeline-parallel rank.
        """
        return self.logits_processor(self.lm_head, hidden_states)

lm_head instance-attribute

# Output projection built only on the last pipeline-parallel rank (other
# ranks get a PPMissingLayer). When `tie_word_embeddings` is set, it is
# replaced by the result of `tie_weights(...)` on the input embeddings.
lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "lm_head"),
)

logits_processor instance-attribute

# Applied together with `lm_head` in `compute_logits`; `logit_scale` is read
# from the text config and defaults to 1.0 when the attribute is absent.
logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size, logit_scale
)

unpadded_vocab_size instance-attribute

# Copied from the text config's `vocab_size`; only set on the last PP rank.
unpadded_vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/transformers/causal.py
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
    # Begin the MRO lookup *after* VllmModelForTextGeneration, so its
    # __init__ is skipped and the next class in the MRO does the setup.
    super(VllmModelForTextGeneration, self).__init__(
        vllm_config=vllm_config, prefix=prefix
    )
    text_cfg = self.text_config
    tied = text_cfg.tie_word_embeddings

    if tied:
        # lm_head will be tied to the input embeddings below, so tell
        # `Base.load_weights` to skip weights under the `lm_head.` prefix.
        self.skip_prefixes.append("lm_head.")

    if not self.pp_group.is_last_rank:
        # Non-last pipeline ranks only hold a placeholder layer.
        self.lm_head = PPMissingLayer()
        return

    vocab = text_cfg.vocab_size
    self.unpadded_vocab_size = vocab
    self.lm_head = ParallelLMHead(
        vocab,
        text_cfg.hidden_size,
        quant_config=self.quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    if tied:
        # Replace the head with one tied to the model's input embeddings.
        self.lm_head = self.lm_head.tie_weights(
            self.model.get_input_embeddings()
        )

    # `logit_scale` is optional in the text config; 1.0 when absent.
    self.logits_processor = LogitsProcessor(
        self.unpadded_vocab_size,
        vocab,
        getattr(text_cfg, "logit_scale", 1.0),
    )

compute_logits

compute_logits(hidden_states: Tensor) -> Tensor | None
Source code in vllm/model_executor/models/transformers/causal.py
def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
    """Project ``hidden_states`` to vocabulary logits.

    Delegates to ``self.logits_processor`` with ``self.lm_head``. Both
    are only created on the last pipeline-parallel rank.
    """
    return self.logits_processor(self.lm_head, hidden_states)