vllm.model_executor.layers.fused_moe.trtllm_moe

TrtLlmGenExperts

Bases: FusedMoEPermuteExpertsUnpermute

An experts stage for vLLM's modular fused-MoE framework that delegates the expert GEMMs to FlashInfer's trtllm_fp4_block_scale_routed_moe kernel. Routing ids and weights are packed into a single int32 tensor, and the kernel performs its own top-k weighting and reduction, so no extra workspaces or finalize step are needed on the vLLM side.

Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        moe: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        gemm1_alpha,
        gemm1_beta,
        gemm1_clamp_limit,
        max_capture_size,
    ):
        super().__init__(quant_config)
        self.moe = moe
        self.gemm1_alpha = gemm1_alpha
        self.gemm1_beta = gemm1_beta
        self.gemm1_clamp_limit = gemm1_clamp_limit
        self.max_capture_size = max_capture_size

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.Standard,
            mk.FusedMoEActivationFormat.Standard,
        )

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # The workspaces for this implementation are managed by flashinfer.
        workspace1 = (0,)
        workspace2 = (0,)
        output = (M, K)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        local_expert_offset = self.moe.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1)

        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16
        ).view(torch.int16)

        assert self.w1_scale is not None
        assert self.w2_scale is not None
        kwargs = {
            "topk_ids": packed_tensor,
            "routing_bias": None,
            "hidden_states": x_quant,
            "hidden_states_scale": x_scale,
            "gemm1_weights": w1,
            "gemm1_weights_scale": self.w1_scale,
            "gemm1_bias": self.w1_bias,
            "gemm1_alpha": self.gemm1_alpha,
            "gemm1_beta": self.gemm1_beta,
            "gemm1_clamp_limit": self.gemm1_clamp_limit,
            "gemm2_weights": w2,
            "gemm2_weights_scale": self.w2_scale,
            "gemm2_bias": self.w2_bias,
            "output1_scale_scalar": None,
            "output1_scale_gate_scalar": None,
            "output2_scale_scalar": None,
            "num_experts": global_num_experts,
            "top_k": topk,
            "n_group": None,
            "topk_group": None,
            "intermediate_size": intermediate_size,
            "local_expert_offset": local_expert_offset,
            "local_num_experts": local_num_experts,
            "routed_scaling_factor": None,
            "tile_tokens_dim": None,
            "routing_method_type": 1,
            "do_finalize": True,
            "output": output,
            "tune_max_num_tokens": self.max_capture_size,
        }

        from flashinfer import trtllm_fp4_block_scale_routed_moe

        trtllm_fp4_block_scale_routed_moe(**kwargs)
        return output
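
For context, a minimal wiring sketch follows. Only the TrtLlmGenExperts constructor is taken from this module; moe_config, quant_config, prepare_finalize, and the chosen values are assumptions standing in for the surrounding setup, with FusedMoEModularKernel referenced as the usual pairing point for a prepare/finalize stage and an experts stage.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts

# moe_config (FusedMoEConfig), quant_config (FusedMoEQuantConfig carrying the
# w1/w2 scales), and prepare_finalize are assumed to be built elsewhere.
experts = TrtLlmGenExperts(
    moe=moe_config,
    quant_config=quant_config,
    gemm1_alpha=None,        # optional activation parameters, see __init__ below
    gemm1_beta=None,
    gemm1_clamp_limit=None,
    max_capture_size=8192,   # hypothetical; forwarded as tune_max_num_tokens
)
kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)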

activation_formats property

gemm1_alpha instance-attribute

gemm1_alpha = gemm1_alpha

gemm1_beta instance-attribute

gemm1_beta = gemm1_beta

gemm1_clamp_limit instance-attribute

gemm1_clamp_limit = gemm1_clamp_limit

max_capture_size instance-attribute

max_capture_size = max_capture_size

moe instance-attribute

moe = moe

__init__

__init__(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    gemm1_alpha,
    gemm1_beta,
    gemm1_clamp_limit,
    max_capture_size,
)
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def __init__(
    self,
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    gemm1_alpha,
    gemm1_beta,
    gemm1_clamp_limit,
    max_capture_size,
):
    super().__init__(quant_config)
    self.moe = moe
    self.gemm1_alpha = gemm1_alpha
    self.gemm1_beta = gemm1_beta
    self.gemm1_clamp_limit = gemm1_clamp_limit
    self.max_capture_size = max_capture_size
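
The three gemm1_* arguments are stored here and later forwarded untouched to the kernel. A plausible instantiation, an assumption modeled on clamped-SwiGLU style models rather than anything in this file, fills one value per expert; passing None for all three leaves the activation unscaled and unclamped.

import torch

num_experts = 8  # hypothetical expert count
gemm1_alpha = torch.full((num_experts,), 1.702, dtype=torch.float32, device="cuda")
gemm1_beta = torch.full((num_experts,), 1.0, dtype=torch.float32, device="cuda")
gemm1_clamp_limit = torch.full((num_experts,), 7.0, dtype=torch.float32, device="cuda")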

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Tensor | None,
    a1q_scale: Tensor | None,
    a2_scale: Tensor | None,
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
)
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
):
    topk = topk_ids.size(-1)
    local_num_experts = w1.size(0)
    intermediate_size = w2.size(1)
    local_expert_offset = self.moe.ep_rank * local_num_experts

    x_quant = hidden_states
    x_scale = a1q_scale
    if x_scale is not None:
        x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1)

    packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
        torch.bfloat16
    ).view(torch.int16)

    assert self.w1_scale is not None
    assert self.w2_scale is not None
    kwargs = {
        "topk_ids": packed_tensor,
        "routing_bias": None,
        "hidden_states": x_quant,
        "hidden_states_scale": x_scale,
        "gemm1_weights": w1,
        "gemm1_weights_scale": self.w1_scale,
        "gemm1_bias": self.w1_bias,
        "gemm1_alpha": self.gemm1_alpha,
        "gemm1_beta": self.gemm1_beta,
        "gemm1_clamp_limit": self.gemm1_clamp_limit,
        "gemm2_weights": w2,
        "gemm2_weights_scale": self.w2_scale,
        "gemm2_bias": self.w2_bias,
        "output1_scale_scalar": None,
        "output1_scale_gate_scalar": None,
        "output2_scale_scalar": None,
        "num_experts": global_num_experts,
        "top_k": topk,
        "n_group": None,
        "topk_group": None,
        "intermediate_size": intermediate_size,
        "local_expert_offset": local_expert_offset,
        "local_num_experts": local_num_experts,
        "routed_scaling_factor": None,
        "tile_tokens_dim": None,
        "routing_method_type": 1,
        "do_finalize": True,
        "output": output,
        "tune_max_num_tokens": self.max_capture_size,
    }

    from flashinfer import trtllm_fp4_block_scale_routed_moe

    trtllm_fp4_block_scale_routed_moe(**kwargs)
    return output
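
The packed_tensor construction above is the one non-obvious step: each int32 entry carries the expert id in its upper 16 bits and the bfloat16 routing weight, reinterpreted as int16, in its lower 16 bits. A round-trip sketch with hypothetical values:

import torch

topk_ids = torch.tensor([[3, 7]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]], dtype=torch.float32)

# Pack: expert id in bits 16..31, bf16 weight bits in bits 0..15. The
# int16 -> int32 promotion relies on routing weights being non-negative
# (bf16 sign bit clear), so no sign extension spills into the id bits.
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)

# Unpack to verify the round trip.
ids_back = packed >> 16
weights_back = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)
assert torch.equal(ids_back, topk_ids)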

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    return TopKWeightAndReduceNoOP()
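
The no-op is deliberate: apply invokes the kernel with do_finalize=True, so top-k weighting and the reduction across experts happen inside flashinfer and nothing is left for vLLM's finalize stage. Continuing the hypothetical experts instance from the wiring sketch above:

reducer = experts.finalize_weight_and_reduce_impl()
assert isinstance(reducer, TopKWeightAndReduceNoOP)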

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def supports_chunking(self) -> bool:
    return True

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def supports_expert_map(self) -> bool:
    return True

workspace_shapes

workspace_shapes(
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: ExpertTokensMetadata | None,
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]
]
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def workspace_shapes(
    self,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    # The workspaces for this implementation are managed by flashinfer.
    workspace1 = (0,)
    workspace2 = (0,)
    output = (M, K)
    return (workspace1, workspace2, output)
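
Because flashinfer allocates its own scratch memory, both intermediate workspaces are empty and only the final (M, K) output buffer is requested from the framework. Continuing the hypothetical experts instance, with illustrative sizes:

ws1, ws2, out_shape = experts.workspace_shapes(
    M=128, N=2048, K=4096, topk=8,
    global_num_experts=64, local_num_experts=8,
    expert_tokens_meta=None,
)
assert ws1 == (0,) and ws2 == (0,) and out_shape == (128, 4096)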