vllm.attention.backends.abstract

T module-attribute

T = TypeVar('T', bound=AttentionMetadata)

AttentionBackend

Bases: ABC

Abstract class for attention backends.

Source code in vllm/attention/backends/abstract.py
class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
        return cls.get_impl_cls().get_supported_kernel_block_size()

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)
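
To make the wiring concrete, here is a minimal, hypothetical sketch (class names such as MyBackend, MyMetadata, MyAttentionImpl, and MyMetadataBuilder are illustrative and not part of vLLM) of how a backend subclass typically just returns its companion classes and declares its KV-cache layout. The layout shown is an arbitrary example, not the layout of any real vLLM backend.

from dataclasses import dataclass

from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata


@dataclass
class MyMetadata(AttentionMetadata):
    # Hypothetical per-batch metadata; real backends carry block tables,
    # sequence lengths, slot mappings, etc.
    num_actual_tokens: int = 0


class MyBackend(AttentionBackend):
    # Hypothetical: the layer may pre-allocate the output tensor and pass it
    # to forward(), which keeps the allocation inside a piecewise cudagraph.
    accept_output_buffer = True

    @staticmethod
    def get_name() -> str:
        return "MY_BACKEND"

    @staticmethod
    def get_impl_cls():
        # Deferred import of the hypothetical implementation class
        # (see the AttentionImpl sketch further down).
        from my_package.attn_impl import MyAttentionImpl
        return MyAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        return MyMetadata

    @staticmethod
    def get_builder_cls():
        from my_package.attn_builder import MyMetadataBuilder  # hypothetical
        return MyMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # Illustrative paged layout: separate K and V planes per block.
        return (2, num_blocks, block_size, num_kv_heads, head_size)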

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = False

full_cls_name classmethod

full_cls_name() -> tuple[str, str]
Source code in vllm/attention/backends/abstract.py
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
    return (cls.__module__, cls.__qualname__)

get_builder_cls abstractmethod staticmethod

get_builder_cls()
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
    raise NotImplementedError

get_impl_cls abstractmethod staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
    raise NotImplementedError

get_kv_cache_shape abstractmethod staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    raise NotImplementedError
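
The tuple returned here, combined with the cache dtype, determines how much memory each layer's KV cache occupies. A back-of-the-envelope sketch, using the illustrative (2, num_blocks, block_size, num_kv_heads, head_size) layout from the example above and made-up sizes rather than any real model's configuration:

import math

num_blocks, block_size = 4096, 16      # 65,536 cacheable token slots
num_kv_heads, head_size = 8, 128
bytes_per_elem = 2                     # fp16 / bf16 cache entries

shape = (2, num_blocks, block_size, num_kv_heads, head_size)
total_bytes = math.prod(shape) * bytes_per_elem
print(f"{total_bytes / 2**20:.0f} MiB per layer")  # 256 MiB per layer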

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/attention/backends/abstract.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    raise NotImplementedError

get_metadata_cls abstractmethod staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    raise NotImplementedError

get_name abstractmethod staticmethod

get_name() -> str
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_name() -> str:
    raise NotImplementedError

get_supported_kernel_block_size classmethod

get_supported_kernel_block_size() -> list[int | MultipleOf]
Source code in vllm/attention/backends/abstract.py
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
    return cls.get_impl_cls().get_supported_kernel_block_size()

make_metadata classmethod

make_metadata(*args, **kwargs) -> AttentionMetadata
Source code in vllm/attention/backends/abstract.py
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
    return cls.get_metadata_cls()(*args, **kwargs)

AttentionImpl

Bases: ABC, Generic[T]

Source code in vllm/attention/backends/abstract.py
class AttentionImpl(ABC, Generic[T]):
    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        # TODO: implement this function for all backends.
        return [MultipleOf(1)]

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        pass
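
A skeletal, hypothetical subclass (illustrative only, not a real vLLM backend) showing how the constructor arguments and the forward contract line up. The body is left as a stub because cache reads/writes, metadata handling, and kernel dispatch are backend-specific.

import torch

from vllm.attention.backends.abstract import (
    AttentionImpl,
    AttentionLayer,
    AttentionMetadata,
    AttentionType,
)


class MyAttentionImpl(AttentionImpl[AttentionMetadata]):
    can_return_lse_for_decode = False  # set True if the kernel can emit softmax LSE

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads or num_heads

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # A real implementation would (1) write key/value into kv_cache at the
        # slots described by attn_metadata, (2) run the attention kernel, and
        # (3) fill `output` in place when the backend sets accept_output_buffer.
        raise NotImplementedError("kernel dispatch is backend-specific")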

can_return_lse_for_decode class-attribute instance-attribute

can_return_lse_for_decode: bool = False

dcp_rank instance-attribute

dcp_rank: int

dcp_world_size instance-attribute

dcp_world_size: int

need_to_return_lse_for_decode class-attribute instance-attribute

need_to_return_lse_for_decode: bool = False

__init__ abstractmethod

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    sliding_window: int | None = None,
    kv_cache_dtype: str = "auto",
    logits_soft_cap: float | None = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: str | None = None,
) -> None
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    sliding_window: int | None = None,
    kv_cache_dtype: str = "auto",
    logits_soft_cap: float | None = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
) -> None:
    raise NotImplementedError

__new__

__new__(*args, **kwargs)
Source code in vllm/attention/backends/abstract.py
def __new__(cls, *args, **kwargs):
    # use __new__ so that all subclasses will call this
    self = super().__new__(cls)
    try:
        from vllm.distributed.parallel_state import get_dcp_group

        self.dcp_world_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group
    except AssertionError:
        # DCP might not be initialized in testing
        self.dcp_world_size = 1
        self.dcp_rank = 0
    self.need_to_return_lse_for_decode = (
        self.dcp_world_size > 1 and self.can_return_lse_for_decode
    )
    return self

forward abstractmethod

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: T,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: T,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    raise NotImplementedError

fused_output_quant_supported

fused_output_quant_supported(quant_key: QuantKey)

Whether this attention implementation supports fused output quantization. This is used by the AttnFusionPass to fuse output quantization only onto implementations that support it.

Parameters:

quant_key: QuantKey object that describes the quantization op.

Returns:

True if fusion is supported for this type of quantization.

Source code in vllm/attention/backends/abstract.py
def fused_output_quant_supported(self, quant_key: QuantKey):
    """
    Does this attention implementation support fused output quantization.
    This is used by the AttnFusionPass to only fuse output quantization
    onto implementations that support it.

    :param quant_key: QuantKey object that describes the quantization op
    :return: is fusion supported for this type of quantization
    """
    return False

get_supported_kernel_block_size staticmethod

get_supported_kernel_block_size() -> list[int | MultipleOf]
Source code in vllm/attention/backends/abstract.py
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
    # TODO: implement this function for all backends.
    return [MultipleOf(1)]
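
The returned list can mix exact block sizes with MultipleOf constraints. A small, hypothetical helper (not part of vLLM) showing how a caller might test a candidate kernel block size against such a list:

from vllm.attention.backends.abstract import MultipleOf


def block_size_is_supported(candidate: int, supported: list[int | MultipleOf]) -> bool:
    # An int entry means "exactly this size"; a MultipleOf(n) entry means
    # "any size divisible by n".
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if candidate % entry.base == 0:
                return True
        elif candidate == entry:
            return True
    return False


print(block_size_is_supported(32, [MultipleOf(16)]))      # True
print(block_size_is_supported(48, [64, MultipleOf(32)]))  # False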

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/attention/backends/abstract.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    pass

supports_quant_query_input

supports_quant_query_input() -> bool

Check if this attention implementation supports pre-quantized query input.

When True, the attention layer will quantize queries before passing them to this backend, allowing torch.compile to fuse the quantization with previous operations. This is typically supported when using an FP8 KV cache with compatible attention kernels (e.g., TRT-LLM). TODO: add support to more backends (https://github.com/vllm-project/vllm/issues/25584).

Returns:

bool: True if the implementation can accept pre-quantized queries.

Source code in vllm/attention/backends/abstract.py
def supports_quant_query_input(self) -> bool:
    """
    Check if this attention implementation supports pre-quantized query input.

    When True, the attention layer will quantize queries before passing them
    to this backend, allowing torch.compile to fuse the quantization with
    previous operations. This is typically supported when using FP8 KV cache
    with compatible attention kernels (e.g., TRT-LLM).
    TODO add support to more backends:
    https://github.com/vllm-project/vllm/issues/25584

    Returns:
        bool: True if the implementation can accept pre-quantized queries.
    """
    return False

AttentionLayer

Bases: Protocol

Source code in vllm/attention/backends/abstract.py
class AttentionLayer(Protocol):
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
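
Because AttentionLayer is a typing.Protocol, any object that exposes these attributes and a compatible forward satisfies it structurally; no inheritance is required. A minimal, hypothetical stand-in with unit scales (e.g. for tests) might look like this:

import torch

from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata


class DummyLayer:
    """Structurally satisfies AttentionLayer with unit (no-op) scales."""

    def __init__(self) -> None:
        self._q_scale = torch.tensor(1.0)
        self._k_scale = torch.tensor(1.0)
        self._v_scale = torch.tensor(1.0)
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        self._prob_scale = torch.tensor(1.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return query  # placeholder; a real layer dispatches to an AttentionImpl


layer: AttentionLayer = DummyLayer()  # accepted by static type checkers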

_k_scale instance-attribute

_k_scale: Tensor

_k_scale_float instance-attribute

_k_scale_float: float

_prob_scale instance-attribute

_prob_scale: Tensor

_q_scale instance-attribute

_q_scale: Tensor

_q_scale_float instance-attribute

_q_scale_float: float

_v_scale instance-attribute

_v_scale: Tensor

_v_scale_float instance-attribute

_v_scale_float: float

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: AttentionMetadata,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor: ...

AttentionMetadata

Source code in vllm/attention/backends/abstract.py
class AttentionMetadata:
    pass

AttentionType

Attention type. Strings are used for compatibility with torch.compile.

Source code in vllm/attention/backends/abstract.py
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""

DECODER class-attribute instance-attribute

DECODER = 'decoder'

Decoder attention between previous layer Q/K/V.

ENCODER class-attribute instance-attribute

ENCODER = 'encoder'

Encoder attention between previous layer Q/K/V for encoder-decoder.

ENCODER_DECODER class-attribute instance-attribute

ENCODER_DECODER = 'encoder_decoder'

Attention between decoder Q and encoder K/V for encoder-decoder models.

ENCODER_ONLY class-attribute instance-attribute

ENCODER_ONLY = 'encoder_only'

Encoder attention between previous layer Q/K/V.

MLAAttentionImpl

Bases: AttentionImpl[T], Generic[T]

Source code in vllm/attention/backends/abstract.py
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

__init__ abstractmethod

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    indexer: object | None = None,
) -> None
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    # MLA Specific Arguments
    q_lora_rank: int | None,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    indexer: object | None = None,
) -> None:
    raise NotImplementedError

forward abstractmethod

forward(
    layer: AttentionLayer,
    hidden_states_or_cq: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: T,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def forward(
    self,
    layer: AttentionLayer,
    hidden_states_or_cq: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: T,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    raise NotImplementedError

MultipleOf

Source code in vllm/attention/backends/abstract.py
class MultipleOf:
    base: int

    def __init__(self, base: int):
        self.base = base

base instance-attribute

base: int

__init__

__init__(base: int)
Source code in vllm/attention/backends/abstract.py
def __init__(self, base: int):
    self.base = base

is_quantized_kv_cache

is_quantized_kv_cache(kv_cache_dtype: str) -> bool
Source code in vllm/attention/backends/abstract.py
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"
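
The convention is simply that any kv_cache_dtype other than the "auto" default (which follows the model dtype) is treated as quantized:

from vllm.attention.backends.abstract import is_quantized_kv_cache

assert not is_quantized_kv_cache("auto")
assert is_quantized_kv_cache("fp8")  # any explicit cache dtype counts as quantized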