@torch.inference_mode()
def _fused_moe_lora(
    output: torch.Tensor, # (num_tokens, top_k_num, len(lora_b_stacked) * slice_output_dim,)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
    expert_ids: torch.Tensor, # (max_loras, _,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
mul_routed_weight: bool = False,
) -> None:
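    """Add the LoRA contribution of a fused MoE layer to ``output``.

    Each LoRA slice is applied in two Triton launches: a shrink GEMM that
    projects the hidden states down to the LoRA rank via ``lora_a_stacked``,
    and an expand GEMM that projects the rank-sized intermediate up to the
    slice's output dimension via ``lora_b_stacked``, optionally scaled by
    ``topk_weights`` when ``mul_routed_weight`` is set. ``sorted_token_ids``,
    ``expert_ids`` and ``num_tokens_post_padded`` carry one block-aligned
    routing per LoRA adapter, and both kernels run over a
    (tiles, slices, adapters) grid. The per-slice results are accumulated into
    contiguous column blocks of ``output``.
    """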
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
sorted_token_ids.dim()
== expert_ids.dim()
== topk_weights.dim()
== qcurr_hidden_states.dim()
== 2
)
assert (
sorted_token_ids.shape[0]
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
}
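    # Shapes for the shrink pass. All slices are assumed to share the layout of
    # slice 0: N is the shrink output dim (the LoRA rank), M the number of
    # tokens, EM the padded length of the per-LoRA sorted token ids, and K the
    # hidden size.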
w1_lora_a_stacked = lora_a_stacked[0]
w1_lora_b_stacked = lora_b_stacked[0]
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
EM = sorted_token_ids.shape[1]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
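    # One flat buffer (kept in bfloat16 here) holds both intermediates for all
    # slices; it is carved into views below.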
lora_intermediate_cache1 = torch.zeros(
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
dtype=torch.bfloat16,
device=device,
)
    # Views over the flat buffer: the shrink output (rank-sized) and the expand
    # output (w1_output_dim_size-sized), one slab per slice.
a_intermediate_size = num_slices * M * top_k_num * max_lora_rank
a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view(
num_slices, M, top_k_num, max_lora_rank
)
b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view(
num_slices, M, top_k_num, w1_output_dim_size
)
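    # Shrink pass: for every (slice, adapter) pair, gather tokens by expert and
    # project the hidden states down to the LoRA rank with lora_a_stacked,
    # writing into a_intermediate_cache1. The LoRA weights are the kernel's
    # second ("B") operand, hence the b_ptr name.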
b_ptr = _get_ptr(lora_a_stacked, device)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
b_ptr,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
        w1_lora_a_stacked.stride(0),
        w1_lora_a_stacked.stride(1),
        w1_lora_a_stacked.stride(3),  # stride over K (lora_a is stored as (..., rank, K))
        w1_lora_a_stacked.stride(2),  # stride over the LoRA rank
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=1,
num_slice_c=num_slices,
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
**config,
)
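    # Expand pass: the same kernel with the LoRA B weights. K becomes the LoRA
    # rank, N the per-slice output dim, and the shrink output is flattened to
    # 2D so the kernel can index it like the original hidden states.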
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
N = w1_output_dim_size
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
b_ptr,
b_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
        w1_lora_b_stacked.stride(0),
        w1_lora_b_stacked.stride(1),
        w1_lora_b_stacked.stride(3),  # stride over the LoRA rank (lora_b is (..., N, rank))
        w1_lora_b_stacked.stride(2),  # stride over N
b_intermediate_cache1.stride(2),
b_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=num_slices,
num_slice_c=num_slices,
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=b_intermediate_cache1.numel() // num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
**config,
)
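    # Accumulate each slice's expanded result into its N-wide column block of
    # the output.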
for i in range(num_slices):
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
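

# Illustrative call sketch (not part of this module): tensor names, shapes and
# the routing helper are assumptions. ``sorted_token_ids`` / ``expert_ids`` /
# ``num_tokens_post_padded`` are expected to come from a block-size-aligned MoE
# routing (e.g. a moe_align_block_size-style helper) with one row per LoRA.
#
#   _fused_moe_lora(
#       output,            # (num_tokens, top_k, num_slices * slice_out_dim)
#       hidden_states,     # (num_tokens, hidden_size)
#       [gate_up_lora_a],  # each (max_loras, num_experts, rank, hidden_size)
#       [gate_up_lora_b],  # each (max_loras, num_experts, slice_out_dim, rank)
#       topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded,
#       max_lora_rank=rank, top_k_num=top_k,
#       block_size_m=64, block_size_n=64, block_size_k=32, group_size_m=8,
#   )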