@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_A_22 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
p_A_33 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
)
p_A_44 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
)
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
# [16, 16]
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
for i in range(32 + 2, min(48, T - i_t * BT)):
b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
for i in range(48 + 2, min(64, T - i_t * BT)):
b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
b_Ai_11 += m_I
b_Ai_22 += m_I
b_Ai_33 += m_I
b_Ai_44 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
p_A_31 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
)
p_A_32 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
)
p_A_41 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
)
p_A_42 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
)
p_A_43 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
)
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
b_Ai_21 = -tl.dot(
tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
b_Ai_11,
input_precision=DOT_PRECISION,
)
b_Ai_32 = -tl.dot(
tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
b_Ai_22,
input_precision=DOT_PRECISION,
)
b_Ai_43 = -tl.dot(
tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
b_Ai_33,
input_precision=DOT_PRECISION,
)
b_Ai_31 = -tl.dot(
b_Ai_33,
tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
+ tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
b_Ai_42 = -tl.dot(
b_Ai_44,
tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
+ tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
b_Ai_41 = -tl.dot(
b_Ai_44,
tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
+ tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
+ tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_Ai_22 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
p_Ai_33 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
)
p_Ai_44 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
)
p_Ai_21 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
p_Ai_31 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
)
p_Ai_32 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
)
p_Ai_41 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
)
p_Ai_42 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
)
p_Ai_43 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
)
tl.store(
p_Ai_11,
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_33,
b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_44,
b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_31,
b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_32,
b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_41,
b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_42,
b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_43,
b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
else:
desc_o.store(
[i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
)