Skip to content

vllm.model_executor.layers.mamba.abstract

MambaBase

Bases: AttentionLayerBase

Base class for Mamba-like layers that support the v1 engine. Inherit from this class when implementing a custom layer.

Source code in vllm/model_executor/layers/mamba/abstract.py
class MambaBase(AttentionLayerBase):
    """
    Abstract base for Mamba-like layers that run on the v1 engine.

    Subclass this when writing a custom layer and implement the abstract
    methods; `get_kv_cache_spec` combines their results with the engine
    config to describe the layer's cache.
    """

    # The layer's KV cache, i.e. its mamba state. The tensor layout follows
    # the shapes reported by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Return the shape of every tensor that makes up the layer state.

        Mamba layers usually keep a (conv_state, ssm_state) pair, in which
        case the result is (conv_state_shape, ssm_state_shape).
        """
        ...

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """String passed to MambaSpec as ``mamba_type``."""
        ...

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        ...

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the dtype of each state tensor (passed to MambaSpec)."""
        ...

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """
        Build the MambaSpec describing this layer's state cache.

        Args:
            vllm_config: Engine configuration; its speculative, model and
                cache sub-configs are consulted.

        Returns:
            A MambaSpec built from this layer's state shapes/dtypes and the
            cache config's mamba block and padded page sizes.

        Raises:
            NotImplementedError: If speculative decoding is configured for a
                model whose ``hf_config.model_type`` is not ``qwen3_next``.
        """
        spec_cfg = vllm_config.speculative_config

        # Speculative decoding is only wired up for the allow-listed models.
        if spec_cfg is not None:
            model_type = vllm_config.model_config.hf_config.model_type
            if model_type not in ["qwen3_next"]:
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet."
                )

        cache_cfg = vllm_config.cache_config
        block_size = cache_cfg.mamba_block_size
        padded_page_size = cache_cfg.mamba_page_size_padded

        # No (truthy) speculative config means no extra speculative blocks.
        extra_blocks = spec_cfg.num_speculative_tokens if spec_cfg else 0

        return MambaSpec(
            shapes=self.get_state_shape(),
            dtypes=self.get_state_dtype(),
            block_size=block_size,
            page_size_padded=padded_page_size,
            mamba_type=self.mamba_type,
            num_speculative_blocks=extra_blocks,
        )

kv_cache instance-attribute

kv_cache: tuple[Tensor, ...]

mamba_type abstractmethod property

mamba_type: str

get_attn_backend abstractmethod

get_attn_backend() -> type[AttentionBackend]

Get the attention backend class for this Mamba layer.

Source code in vllm/model_executor/layers/mamba/abstract.py
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
    """
    Return the attention backend class this Mamba layer uses.

    Abstract: concrete layers must override it.
    """
    ...

get_kv_cache_spec

get_kv_cache_spec(
    vllm_config: VllmConfig,
) -> KVCacheSpec | None
Source code in vllm/model_executor/layers/mamba/abstract.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
    """
    Build the MambaSpec describing this layer's state cache.

    Args:
        vllm_config: Engine configuration; its speculative, model and
            cache sub-configs are consulted.

    Returns:
        A MambaSpec built from this layer's state shapes/dtypes and the
        cache config's mamba block and padded page sizes.

    Raises:
        NotImplementedError: If speculative decoding is configured for a
            model whose ``hf_config.model_type`` is not ``qwen3_next``.
    """
    spec_cfg = vllm_config.speculative_config

    # Speculative decoding is only wired up for the allow-listed models.
    if spec_cfg is not None:
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type not in ["qwen3_next"]:
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet."
            )

    cache_cfg = vllm_config.cache_config
    block_size = cache_cfg.mamba_block_size
    padded_page_size = cache_cfg.mamba_page_size_padded

    # No (truthy) speculative config means no extra speculative blocks.
    extra_blocks = spec_cfg.num_speculative_tokens if spec_cfg else 0

    return MambaSpec(
        shapes=self.get_state_shape(),
        dtypes=self.get_state_dtype(),
        block_size=block_size,
        page_size_padded=padded_page_size,
        mamba_type=self.mamba_type,
        num_speculative_blocks=extra_blocks,
    )

get_state_dtype abstractmethod

get_state_dtype() -> tuple[dtype, ...]
Source code in vllm/model_executor/layers/mamba/abstract.py
@abstractmethod
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
    """Return the dtype of each tensor in the layer state."""
    ...

get_state_shape abstractmethod

get_state_shape() -> Iterable[tuple[int, ...]]

Defines the shape of the state. For Mamba layers the state is usually a (conv_state, ssm_state) tuple, in which case this returns (conv_state_shape, ssm_state_shape).

Source code in vllm/model_executor/layers/mamba/abstract.py
@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
    """
    Return the shape of every tensor that makes up the layer state.

    Mamba layers usually keep a (conv_state, ssm_state) pair, in which
    case the result is (conv_state_shape, ssm_state_shape).
    """
    ...