Skip to content

vllm.lora.layers.fused_moe

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``FusedMoE`` layer.

    The wrapper keeps the original layer as ``base_layer`` and swaps the
    layer's ``quant_method.fused_experts`` for a modular fused-MoE kernel whose
    ``forward``, ``activation`` and ``moe_sum`` stages are wrapped so that the
    per-expert LoRA contributions (w1/w3 before the activation, w2 before the
    final sum) are added through ``self.punica_wrapper``.

    Per-expert LoRA weights live in six stacked buffers
    (``w{1,2,3}_lora_{a,b}_stacked``) indexed as
    ``[lora_slot, expert, rows, cols]``.
    """

    def __init__(self, base_layer: FusedMoE) -> None:
        """Wrap ``base_layer`` and install the LoRA hooks on its experts kernel."""
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # LoRA buffers are later allocated on the device of the expert weights.
        self.device = base_layer.w2_weight.device
        self._inject_lora_into_fused_moe()

    def _inject_lora_into_fused_moe(self):
        """Replace the base layer's fused-experts kernel with a LoRA-aware one.

        A modular fused-MoE kernel is built (Marlin when the quant config
        reports ``use_mxfp4_w4a16``, Triton otherwise) and three of its
        callables are wrapped:

        * ``forward``    - records the routing inputs of the current call.
        * ``activation`` - adds the w1/w3 LoRA output before the activation.
        * ``moe_sum``    - adds the w2 LoRA output before the final reduction.

        The wrappers communicate through ``moe_state_dict``, a dict captured
        by closure, so they rely on being invoked in that order within one
        forward call. The previous kernel is kept on the quant method as
        ``old_fused_experts``.
        """
        # Scratch state handed from one wrapped stage to the next.
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        if self.base_layer.quant_config is None:
            quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
            quant_config = self.base_layer.quant_config
        else:
            quant_config = mxfp4_w4a16_moe_quant_config(
                w1_bias=self.base_layer.w13_bias,
                w2_bias=self.base_layer.w2_bias,
                w1_scale=self.base_layer.w13_weight_scale,
                w2_scale=self.base_layer.w2_weight_scale,
            )

        m_fused_moe_fn = (
            modular_triton_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
            if not quant_config.use_mxfp4_w4a16
            else modular_marlin_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
        )

        def lora_kernel_config(layer):
            """Return the MoE tile config for the hidden states of this call.

            Shared by the activation and moe_sum wrappers so both LoRA
            launches use the same config.
            """
            hidden_states = moe_state_dict["hidden_states"]
            config_dtype = _get_config_dtype_str(
                dtype=hidden_states.dtype,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
            )
            # The token count is capped at the fused-MoE chunk size.
            M = min(hidden_states.size(0), envs.VLLM_FUSED_MOE_CHUNK_SIZE)
            return try_get_optimal_moe_config(
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                M,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )

        def fwd_decorator(layer, func):
            def wrapper(*args, **kwargs):
                # The routing inputs are read from kwargs only, so the kernel
                # must be called with these arguments passed by keyword.
                moe_state_dict["hidden_states"] = kwargs["hidden_states"]
                moe_state_dict["topk_ids"] = kwargs["topk_ids"]
                moe_state_dict["topk_weights"] = kwargs["topk_weights"]
                moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
                moe_state_dict["expert_map"] = kwargs["expert_map"]
                moe_state_dict["apply_router_weight_on_input"] = kwargs[
                    "apply_router_weight_on_input"
                ]
                # Leading dim of the stacked buffers is the number of slots.
                moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
                return func(*args, **kwargs)

            return wrapper

        def act_decorator(layer, func):
            def wrapper(*args, **kwargs):
                # First positional argument is not needed here.
                _, output, act_input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]
                global_num_experts = moe_state_dict["global_num_experts"]
                expert_map = moe_state_dict["expert_map"]
                max_loras = moe_state_dict["max_loras"]

                num_tokens = hidden_states.size(0)
                config = lora_kernel_config(layer)
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    config["BLOCK_SIZE_M"],
                    global_num_experts,
                    max_loras,
                    expert_map,
                )

                # Kept (un-reshaped) for the w2 LoRA pass in the moe_sum hook.
                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )

                w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
                w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
                max_lora_rank = self.w1_lora_a_stacked.shape[-2]
                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)

                # Add the w1/w3 LoRA result onto the activation input before
                # the wrapped activation consumes it.
                self.punica_wrapper.add_lora_fused_moe(
                    act_input.view(-1, top_k, act_input.shape[-1]),
                    hidden_states,
                    w13_lora_a_stacked,
                    w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    config,
                )

                result = func(*args, **kwargs)

                # The activation output is the input of the w2 LoRA pass.
                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            def wrapper(*args, **kwargs):
                topk_weights = moe_state_dict["topk_weights"]
                max_loras = moe_state_dict["max_loras"]
                config = lora_kernel_config(layer)

                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]

                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]
                # All stacks are allocated with the same rank dimension in
                # create_lora_weights; read it from the w2 stack used here.
                max_lora_rank = self.w2_lora_a_stacked.shape[-2]
                # NOTE(review): the trailing ``True`` is passed positionally;
                # its meaning is defined by the punica wrapper - confirm there.
                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    [self.w2_lora_a_stacked],
                    [self.w2_lora_b_stacked],
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    config,
                    True,
                )

                return func(*args, **kwargs)

            return wrapper

        fused_experts = m_fused_moe_fn.fused_experts

        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )

        # Keep a handle on the kernel being replaced.
        self.base_layer.quant_method.old_fused_experts = (
            self.base_layer.quant_method.fused_experts
        )
        self.base_layer.quant_method.fused_experts = m_fused_moe_fn

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate the zero-initialised stacked LoRA buffers.

        Args:
            max_loras: Number of adapter slots (leading dim of every buffer).
            lora_config: Supplies ``max_lora_rank`` and ``lora_dtype``.
            model_config: Unused here.

        Raises:
            AssertionError: If the base layer uses expert parallelism.
        """

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )

        num_experts = self.base_layer.global_num_experts
        rank = lora_config.max_lora_rank
        hidden_size = self.base_layer.hidden_size
        intermediate_size = self.base_layer.intermediate_size_per_partition

        def zeros(rows: int, cols: int) -> torch.Tensor:
            # One [slot, expert, rows, cols] buffer on the layer's device.
            return torch.zeros(
                (max_loras, num_experts, rows, cols),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )

        self.w1_lora_a_stacked = zeros(rank, hidden_size)
        self.w1_lora_b_stacked = zeros(intermediate_size, rank)
        self.w2_lora_a_stacked = zeros(rank, intermediate_size)
        self.w2_lora_b_stacked = zeros(hidden_size, rank)
        self.w3_lora_a_stacked = zeros(rank, hidden_size)
        self.w3_lora_b_stacked = zeros(intermediate_size, rank)

        # The forward hook reads the slot count from the base layer.
        self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
        self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
        self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
        self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
        self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
        self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create dummy LoRA weights.
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(num_experts):
                # gate_proj,down_proj,up_proj
                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
                self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])

                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
                self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        self.w1_lora_a_stacked[index] = 0
        self.w1_lora_b_stacked[index] = 0
        self.w3_lora_a_stacked[index] = 0
        self.w3_lora_b_stacked[index] = 0
        self.w2_lora_a_stacked[index] = 0
        self.w2_lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
        bias: torch.Tensor | None = None,
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` / ``lora_b`` are indexed as flat sequences holding three
        entries per expert, in the order (w1, w2, w3). ``embeddings_tensor``
        and ``bias`` are accepted but not used by this layer.

        With tensor parallelism (``tp_size > 1``) the weights are sliced along
        the intermediate dimension to this rank's shard before being stored.
        """
        # Only the top-left [rows, cols] block of each buffer is written
        # below, so clear the slot first: otherwise an adapter with a smaller
        # rank would leave rows/columns of the previous adapter in place.
        self.reset_lora(index)

        for eid in range(len(lora_a) // 3):
            w1_lora_a = lora_a[eid * 3]
            w2_lora_a = lora_a[eid * 3 + 1]
            w3_lora_a = lora_a[eid * 3 + 2]
            w1_lora_b = lora_b[eid * 3]
            w2_lora_b = lora_b[eid * 3 + 1]
            w3_lora_b = lora_b[eid * 3 + 2]

            if self.tp_size > 1:
                shard_size = self.base_layer.intermediate_size_per_partition
                start_idx = self.tp_rank * shard_size
                end_idx = (self.tp_rank + 1) * shard_size

                # Only tensors with an intermediate-size axis are sharded.
                w1_lora_b = w1_lora_b[start_idx:end_idx, :]
                w3_lora_b = w3_lora_b[start_idx:end_idx, :]
                w2_lora_a = w2_lora_a[:, start_idx:end_idx]

            for stacked, weight in (
                (self.w1_lora_a_stacked, w1_lora_a),
                (self.w3_lora_a_stacked, w3_lora_a),
                (self.w2_lora_b_stacked, w2_lora_b),
                (self.w1_lora_b_stacked, w1_lora_b),
                (self.w3_lora_b_stacked, w3_lora_b),
                (self.w2_lora_a_stacked, w2_lora_a),
            ):
                stacked[index, eid, : weight.shape[0], : weight.shape[1]].copy_(
                    weight, non_blocking=True
                )

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # isinstance (rather than an exact type match) also accepts
        # subclasses of FusedMoE.
        return isinstance(source_layer, FusedMoE)

    def forward(self, *args, **kwargs):
        """Run the wrapped layer; LoRA is applied by the injected kernel hooks."""
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        """Forward the call to the wrapped layer unchanged."""
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        """The wrapped layer's ``_shared_experts``."""
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        """The wrapped layer's ``quant_method``."""
        return self.base_layer.quant_method

_shared_experts property

_shared_experts

base_layer instance-attribute

base_layer = base_layer

device instance-attribute

device = device

quant_method property

quant_method

tp_rank instance-attribute

tp_size instance-attribute

__init__

__init__(base_layer: FusedMoE) -> None
Source code in vllm/lora/layers/fused_moe.py
def __init__(self, base_layer: FusedMoE) -> None:
    """Wrap ``base_layer`` and hook LoRA into its fused-experts kernel.

    Records the tensor-parallel world size and rank and the device of the
    base layer's ``w2_weight``, then installs the LoRA hooks.
    """
    super().__init__()
    self.base_layer = base_layer
    # Tensor-parallel topology of the current process.
    self.tp_size, self.tp_rank = (
        get_tensor_model_parallel_world_size(),
        get_tensor_model_parallel_rank(),
    )
    expert_weight = base_layer.w2_weight
    self.device = expert_weight.device
    self._inject_lora_into_fused_moe()

_inject_lora_into_fused_moe

_inject_lora_into_fused_moe()
Source code in vllm/lora/layers/fused_moe.py
def _inject_lora_into_fused_moe(self):
    """Replace the base layer's fused-experts kernel with a LoRA-aware one.

    Builds a modular fused-MoE kernel (Marlin when the quant config reports
    ``use_mxfp4_w4a16``, Triton otherwise) and wraps three of its callables:

    * ``forward``    - records the routing inputs of the current call.
    * ``activation`` - adds the w1/w3 LoRA output, then runs the activation.
    * ``moe_sum``    - adds the w2 LoRA output, then runs the final sum.

    The wrappers exchange data through ``moe_state_dict`` (captured by
    closure), so they depend on being invoked in that order during a single
    forward call. The kernel being replaced is kept on the quant method as
    ``old_fused_experts``.
    """
    # Scratch state handed from one wrapped stage to the next.
    moe_state_dict = {}
    top_k = self.base_layer.top_k

    # Choose the MoE quant config: the unquantized default, the layer's own
    # config, or (for Mxfp4Config) one built from the layer's bias/scales.
    if self.base_layer.quant_config is None:
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
    elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
        quant_config = self.base_layer.quant_config
    else:
        quant_config = mxfp4_w4a16_moe_quant_config(
            w1_bias=self.base_layer.w13_bias,
            w2_bias=self.base_layer.w2_bias,
            w1_scale=self.base_layer.w13_weight_scale,
            w2_scale=self.base_layer.w2_weight_scale,
        )

    m_fused_moe_fn = (
        modular_triton_fused_moe(
            quant_config, shared_experts=self.base_layer.shared_experts
        )
        if not quant_config.use_mxfp4_w4a16
        else modular_marlin_fused_moe(
            quant_config, shared_experts=self.base_layer.shared_experts
        )
    )

    def fwd_decorator(layer, func):
        def wrapper(*args, **kwargs):
            # Values are read from kwargs only: the wrapped forward must be
            # called with these arguments by keyword, else KeyError.
            moe_state_dict["hidden_states"] = kwargs["hidden_states"]
            moe_state_dict["topk_ids"] = kwargs["topk_ids"]
            moe_state_dict["topk_weights"] = kwargs["topk_weights"]
            moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
            moe_state_dict["expert_map"] = kwargs["expert_map"]
            # Stored but not read by the other wrappers in this function.
            moe_state_dict["apply_router_weight_on_input"] = kwargs[
                "apply_router_weight_on_input"
            ]
            # Leading dim of the stacked LoRA buffer = number of adapter slots.
            moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
            result = func(*args, **kwargs)
            return result

        return wrapper

    def act_decorator(layer, func):
        def wrapper(*args, **kwargs):
            # Exactly three positional arguments are expected; the first is
            # not used here. NOTE: ``input`` shadows the builtin.
            _, output, input = args

            hidden_states = moe_state_dict["hidden_states"]
            topk_weights = moe_state_dict["topk_weights"]
            curr_topk_ids = moe_state_dict["topk_ids"]
            global_num_experts = moe_state_dict["global_num_experts"]
            expert_map = moe_state_dict["expert_map"]
            max_loras = moe_state_dict["max_loras"]

            config_dtype = _get_config_dtype_str(
                dtype=hidden_states.dtype,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
            )
            CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
            num_tokens = hidden_states.size(0)
            # Token count used for the config lookup is capped at the chunk size.
            M = min(num_tokens, CHUNK_SIZE)

            get_config_func = functools.partial(
                try_get_optimal_moe_config,
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )

            config = get_config_func(M)
            (
                sorted_token_ids_lora,
                expert_ids_lora,
                num_tokens_post_padded_lora,
            ) = self.punica_wrapper.moe_lora_align_block_size(
                curr_topk_ids,
                num_tokens,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                max_loras,
                expert_map,
            )

            # Saved in their original (un-reshaped) form for the moe_sum hook.
            moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
            moe_state_dict["expert_ids_lora"] = expert_ids_lora
            moe_state_dict["num_tokens_post_padded_lora"] = (
                num_tokens_post_padded_lora
            )

            w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
            w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
            max_lora_rank = self.w1_lora_a_stacked.shape[-2]
            # One row per adapter slot.
            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)

            # w1/w3 LoRA pass, issued before the wrapped activation runs.
            # NOTE(review): ``input`` is presumably the gate/up projection
            # output that the LoRA result is accumulated into - confirm
            # against the punica wrapper.
            self.punica_wrapper.add_lora_fused_moe(
                input.view(-1, top_k, input.shape[-1]),
                hidden_states,
                w13_lora_a_stacked,
                w13_lora_b_stacked,
                topk_weights,
                sorted_token_ids_lora,
                expert_ids_lora,
                num_tokens_post_padded_lora,
                max_lora_rank,
                top_k,
                config,
            )

            result = func(*args, **kwargs)

            # The activation's ``output`` argument feeds the w2 LoRA pass.
            moe_state_dict["intermediate_cache2"] = output
            return result

        return wrapper

    def moe_sum_decorator(layer, func):
        def wrapper(*args, **kwargs):
            hidden_states = moe_state_dict["hidden_states"]
            topk_weights = moe_state_dict["topk_weights"]
            max_loras = moe_state_dict["max_loras"]

            # Same config computation as in the activation wrapper.
            config_dtype = _get_config_dtype_str(
                dtype=hidden_states.dtype,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
            )
            CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
            num_tokens = hidden_states.size(0)
            M = min(num_tokens, CHUNK_SIZE)

            get_config_func = functools.partial(
                try_get_optimal_moe_config,
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )

            config = get_config_func(M)

            # Alignment data produced by the activation wrapper.
            sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
            expert_ids_lora = moe_state_dict["expert_ids_lora"]
            num_tokens_post_padded_lora = moe_state_dict[
                "num_tokens_post_padded_lora"
            ]

            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
            intermediate_cache2 = moe_state_dict["intermediate_cache2"]
            # First positional argument of moe_sum is used as the LoRA target.
            intermediate_cache3 = args[0]
            # Rank is read from the w1 stack; create_lora_weights allocates
            # every stack with the same rank dimension.
            max_lora_rank = self.w1_lora_a_stacked.shape[-2]
            # w2 LoRA pass, issued before the wrapped moe_sum runs.
            # NOTE(review): the trailing positional ``True`` is defined by
            # the punica wrapper's signature - confirm its meaning there.
            self.punica_wrapper.add_lora_fused_moe(
                intermediate_cache3,
                intermediate_cache2,
                [self.w2_lora_a_stacked],
                [self.w2_lora_b_stacked],
                topk_weights,
                sorted_token_ids_lora,
                expert_ids_lora,
                num_tokens_post_padded_lora,
                max_lora_rank,
                top_k,
                config,
                True,
            )

            result = func(*args, **kwargs)
            return result

        return wrapper

    fused_experts = m_fused_moe_fn.fused_experts

    # Install the three hooks on the freshly built kernel objects.
    m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
    fused_experts.activation = act_decorator(
        self.base_layer, fused_experts.activation
    )
    fused_experts.moe_sum = moe_sum_decorator(
        self.base_layer, fused_experts.moe_sum
    )

    # Keep the replaced kernel reachable, then swap in the hooked one.
    self.base_layer.quant_method.old_fused_experts = (
        self.base_layer.quant_method.fused_experts
    )
    self.base_layer.quant_method.fused_experts = m_fused_moe_fn

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""
    # return type(source_layer) is FusedMoE
    return isinstance(source_layer, FusedMoE)

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Allocate zero-filled stacked LoRA buffers for every expert.

    Six buffers shaped ``[max_loras, num_experts, rows, cols]`` are created
    on ``self.device`` with dtype ``lora_config.lora_dtype``, stored on this
    wrapper and mirrored onto the base layer. ``model_config`` is unused.

    Raises:
        AssertionError: If the base layer uses expert parallelism.
    """

    assert not self.base_layer.use_ep, (
        "EP support for Fused MoE LoRA is not implemented yet."
    )

    base = self.base_layer
    n_experts = base.global_num_experts
    rank = lora_config.max_lora_rank
    hidden = base.hidden_size
    inter = base.intermediate_size_per_partition

    def _blank(rows, cols):
        # One [slot, expert, rows, cols] buffer of zeros.
        return torch.zeros(
            (max_loras, n_experts, rows, cols),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    self.w1_lora_a_stacked = _blank(rank, hidden)
    self.w1_lora_b_stacked = _blank(inter, rank)

    self.w2_lora_a_stacked = _blank(rank, inter)
    self.w2_lora_b_stacked = _blank(hidden, rank)

    self.w3_lora_a_stacked = _blank(rank, hidden)
    self.w3_lora_b_stacked = _blank(inter, rank)

    # Expose the very same tensors on the wrapped layer.
    base.w1_lora_a_stacked = self.w1_lora_a_stacked
    base.w1_lora_b_stacked = self.w1_lora_b_stacked
    base.w2_lora_a_stacked = self.w2_lora_a_stacked
    base.w2_lora_b_stacked = self.w2_lora_b_stacked
    base.w3_lora_a_stacked = self.w3_lora_a_stacked
    base.w3_lora_b_stacked = self.w3_lora_b_stacked

    # Flat per-(slot, expert) views consumed by
    # 'LoRALayerWeights.create_dummy_lora_weights' when it builds dummy LoRA
    # weights. Within each (slot, expert) pair the order is
    # gate_proj, down_proj, up_proj (w1, w2, w3).
    a_stacks = (
        self.w1_lora_a_stacked,
        self.w2_lora_a_stacked,
        self.w3_lora_a_stacked,
    )
    b_stacks = (
        self.w1_lora_b_stacked,
        self.w2_lora_b_stacked,
        self.w3_lora_b_stacked,
    )
    slots = [
        (lora_id, experts_id)
        for lora_id in range(max_loras)
        for experts_id in range(n_experts)
    ]
    self.lora_a_stacked = [
        stack[lora_id][experts_id]
        for lora_id, experts_id in slots
        for stack in a_stacks
    ]
    self.lora_b_stacked = [
        stack[lora_id][experts_id]
        for lora_id, experts_id in slots
        for stack in b_stacks
    ]

forward

forward(*args, **kwargs)
Source code in vllm/lora/layers/fused_moe.py
def forward(self, *args, **kwargs):
    """Delegate the forward pass to the wrapped base layer unchanged."""
    wrapped = self.base_layer
    return wrapped.forward(*args, **kwargs)

maybe_all_reduce_tensor_model_parallel

maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
Source code in vllm/lora/layers/fused_moe.py
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
    """Forward the call, with all arguments, to the wrapped base layer."""
    delegate = self.base_layer.maybe_all_reduce_tensor_model_parallel
    return delegate(*args, **kwargs)

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Zero all six stacked LoRA buffers for the adapter slot ``index``."""
    buffer_names = (
        "w1_lora_a_stacked",
        "w1_lora_b_stacked",
        "w3_lora_a_stacked",
        "w3_lora_b_stacked",
        "w2_lora_a_stacked",
        "w2_lora_b_stacked",
    )
    for attr in buffer_names:
        # In-place write into the slot; other slots are left untouched.
        getattr(self, attr)[index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor,
    lora_b: Tensor,
    embeddings_tensor: Tensor | None,
    bias: Tensor | None = None,
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    embeddings_tensor: torch.Tensor | None,
    bias: torch.Tensor | None = None,
):
    """Overwrites lora tensors at index."""
    for eid in range(len(lora_a) // 3):
        w1_lora_a = lora_a[eid * 3]
        w2_lora_a = lora_a[eid * 3 + 1]
        w3_lora_a = lora_a[eid * 3 + 2]
        w1_lora_b = lora_b[eid * 3]
        w2_lora_b = lora_b[eid * 3 + 1]
        w3_lora_b = lora_b[eid * 3 + 2]

        if self.tp_size > 1:
            shard_size = self.base_layer.intermediate_size_per_partition
            start_idx = self.tp_rank * shard_size
            end_idx = (self.tp_rank + 1) * shard_size

            w1_lora_b = w1_lora_b[start_idx:end_idx, :]
            w3_lora_b = w3_lora_b[start_idx:end_idx, :]
            w2_lora_a = w2_lora_a[:, start_idx:end_idx]

        self.w1_lora_a_stacked[
            index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
        ].copy_(w1_lora_a, non_blocking=True)

        self.w3_lora_a_stacked[
            index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
        ].copy_(w3_lora_a, non_blocking=True)

        self.w2_lora_b_stacked[
            index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
        ].copy_(w2_lora_b, non_blocking=True)

        self.w1_lora_b_stacked[
            index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
        ].copy_(w1_lora_b, non_blocking=True)
        self.w3_lora_b_stacked[
            index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
        ].copy_(w3_lora_b, non_blocking=True)
        self.w2_lora_a_stacked[
            index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
        ].copy_(w2_lora_a, non_blocking=True)