Support cutlass Int8 gemm: https://github.com/sgl-project/sglang/pull/2752/files
Main file: sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
Main function: cutlass_int8_scaled_mm
    template <
        typename ElementOutput,
        typename ArchTag,
        typename ThreadblockShape,
        typename WarpShape,
        typename InstructionShape,
        int NumStages
    >
    void cutlass_int8_scaled_mm(
        torch::Tensor& out,
        const torch::Tensor& mat_a,
        const torch::Tensor& mat_b,
        const torch::Tensor& scales_a,
        const torch::Tensor& scales_b,
        const c10::optional<torch::Tensor>& bias)
    { ... }
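The signature alone does not say what the kernel computes. Below is a minimal CPU sketch, assuming the usual int8 quantized-GEMM semantics: int32 accumulation, one scale per row of mat_a, one scale per column of mat_b, and an optional per-column bias. The name int8_scaled_mm_reference, the row-major layout for both operands, and the float output are illustrative choices, not taken from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// CPU reference for an int8 scaled matmul (illustrative; not the CUTLASS kernel).
// a: M x K row-major int8, b: K x N row-major int8 (layout assumed for simplicity).
// scales_a: one float per row of a; scales_b: one float per column of b.
std::vector<float> int8_scaled_mm_reference(
    const std::vector<int8_t>& a,
    const std::vector<int8_t>& b,
    const std::vector<float>& scales_a,
    const std::vector<float>& scales_b,
    const std::optional<std::vector<float>>& bias,
    std::size_t M, std::size_t N, std::size_t K) {
  assert(a.size() == M * K && b.size() == K * N);
  assert(scales_a.size() == M && scales_b.size() == N);
  assert(!bias || bias->size() == N);

  std::vector<float> out(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      int32_t acc = 0;  // int8 products are accumulated in int32
      for (std::size_t k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(a[m * K + k]) *
               static_cast<int32_t>(b[k * N + n]);
      }
      // Dequantize: apply the row scale and the column scale, then the bias.
      float v = static_cast<float>(acc) * scales_a[m] * scales_b[n];
      if (bias) v += (*bias)[n];
      out[m * N + n] = v;
    }
  }
  return out;
}
```

A reference like this is mainly useful for checking a GPU kernel's output on small inputs; it needs C++17 for std::optional.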
int8_scaled_mm: entry point
smXX_dispatch_shape: dispatch based on GPU architecture (XX is the SM version)
cutlass_int8_scaled_mm: core kernel
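The three-level call chain can be sketched without CUTLASS or PyTorch. This is a hedged illustration only: the tag structs, entry_point, dispatch_shape, core_kernel, the size thresholds, and the tile shapes are all hypothetical stand-ins. It also assumes that the entry point picks the architecture branch and that the per-architecture function picks the tile shape, which the name smXX_dispatch_shape and the template parameters suggest but the notes do not state.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical architecture tags (stand-ins for the ArchTag template argument).
struct Sm75Tag { static const char* name() { return "sm75"; } };
struct Sm80Tag { static const char* name() { return "sm80"; } };

// Stand-in for the core kernel: the tile shape is a compile-time parameter.
// Here it only reports which instantiation was selected.
template <typename ArchTag, int TileM, int TileN, int TileK>
std::string core_kernel() {
  return std::string(ArchTag::name()) + ":" + std::to_string(TileM) + "x" +
         std::to_string(TileN) + "x" + std::to_string(TileK);
}

// Stand-in for smXX_dispatch_shape: maps a runtime problem size (m) to a
// compile-time tile shape. Thresholds and shapes are made up for illustration.
template <typename ArchTag>
std::string dispatch_shape(int m) {
  if (m <= 32) return core_kernel<ArchTag, 32, 128, 64>();
  if (m <= 128) return core_kernel<ArchTag, 64, 128, 64>();
  return core_kernel<ArchTag, 128, 128, 64>();
}

// Stand-in for the entry point: selects the architecture branch at runtime.
std::string entry_point(int sm_version, int m) {
  if (sm_version >= 80) return dispatch_shape<Sm80Tag>(m);
  if (sm_version >= 75) return dispatch_shape<Sm75Tag>(m);
  throw std::invalid_argument("unsupported SM version");
}
```

The point of the pattern is that template arguments must be known at compile time, so each runtime decision (architecture, problem size) becomes a branch that calls a different instantiation.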
Prepare pointers, strides, and dimensions from the input tensors
Allocate workspace
    auto workspace = torch::empty(gemm_op.get_workspace_size(args),
        torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
Launch kernel
    auto status = gemm_op(args, workspace.data_ptr(), stream);
gemm_op: CUTLASS execution
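The allocate-then-launch steps follow a common pattern: ask the operator how much scratch memory it needs, allocate that many bytes, run, and check the returned status. The sketch below imitates that flow with a mock operator so it compiles without CUTLASS. MockGemmOp, Args, Status, run_gemm, and the workspace-size formula are hypothetical; only the call shape (get_workspace_size(args), then gemm_op(args, workspace_ptr)) mirrors the snippets above, with the stream argument omitted.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical status codes for the mock operator.
enum class Status { kSuccess, kErrorWorkspaceNull };

// Hypothetical problem description.
struct Args { std::size_t m, n, k; };

// Mock operator with the same calling convention as gemm_op in the notes.
struct MockGemmOp {
  // Made-up rule: one int32 of scratch space per output row.
  std::size_t get_workspace_size(const Args& args) const {
    return args.m * sizeof(int32_t);
  }
  Status operator()(const Args& args, void* workspace) const {
    if (get_workspace_size(args) > 0 && workspace == nullptr) {
      return Status::kErrorWorkspaceNull;
    }
    return Status::kSuccess;  // a real operator would launch the kernel here
  }
};

// Query the workspace size, allocate, run, and fail loudly on a bad status.
// Returns the number of workspace bytes that were allocated.
std::size_t run_gemm(const Args& args) {
  MockGemmOp gemm_op;
  std::vector<uint8_t> workspace(gemm_op.get_workspace_size(args));
  void* ptr = workspace.empty() ? nullptr : workspace.data();
  Status status = gemm_op(args, ptr);
  if (status != Status::kSuccess) {
    throw std::runtime_error("GEMM launch failed");
  }
  return workspace.size();
}
```

Allocating the workspace as raw bytes (kUInt8 in the original snippet, std::vector<uint8_t> here) works because the operator reports its requirement in bytes. The returned status should be checked rather than discarded.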