Intro


PR adding CUTLASS INT8 GEMM support: https://github.com/sgl-project/sglang/pull/2752/files

Main file: sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu

Main function: cutlass_int8_scaled_mm

template <
    typename ElementOutput,
    typename ArchTag,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    int NumStages>
void cutlass_int8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias)
{ ... }
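
The template parameters pick the output element type, the target architecture, and the CUTLASS tile configuration. As a point of reference, a concrete instantiation could look like the sketch below; the tile shapes and stage count are illustrative placeholders rather than the exact configurations the dispatcher picks, and the tensors are assumed to be prepared by the caller.

// Illustrative instantiation for an Ampere (sm80) GPU; the actual
// ThreadblockShape/WarpShape/NumStages are chosen by the dispatch layer,
// so the values below are examples only.
cutlass_int8_scaled_mm<
    cutlass::half_t,                         // ElementOutput: fp16 output
    cutlass::arch::Sm80,                     // ArchTag: target architecture
    cutlass::gemm::GemmShape<128, 128, 64>,  // ThreadblockShape (M, N, K tile)
    cutlass::gemm::GemmShape<64, 64, 64>,    // WarpShape
    cutlass::gemm::GemmShape<16, 8, 32>,     // InstructionShape for sm80 int8 tensor cores
    5                                        // NumStages: software pipeline depth
>(out, mat_a, mat_b, scales_a, scales_b, bias);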

Code Flow


  1. int8_scaled_mm: entry point

    1. Takes mat_a, mat_b, scales_a, scales_b, and bias as input and allocates the output tensor out
  2. smXX_dispatch_shape: per-architecture dispatch

    1. The entry point selects smXX from the GPU's compute capability; each smXX_dispatch_shape then instantiates cutlass_int8_scaled_mm with a configuration (ElementOutput, tile shapes, NumStages) suited to that architecture (see the dispatch sketch after this list)
  3. cutlass_int8_scaled_mm: core kernel

    1. Prepare device pointers, strides, and problem dimensions from the input tensors and pack them into the CUTLASS arguments (args)

    2. Allocate workspace

      auto workspace = torch::empty(gemm_op.get_workspace_size(args),
                                    torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
      
    3. Launch the kernel (a fuller launch sequence with a status check is sketched after this list)

      auto status = gemm_op(args, workspace.data_ptr(), stream);
      
  4. gemm_op: CUTLASS executes the GEMM on the given CUDA stream and writes the result into out
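
To make steps 1 and 2 concrete, here is a minimal sketch of the dispatch layer. It assumes the smXX_dispatch_shape helpers named above; the helper signatures, the output dtype, and the architecture cut-off are placeholders for illustration, not copied from int8_gemm_kernel.cu.

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

// Assumed declarations of the per-architecture dispatchers (signatures illustrative).
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias);
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                         const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                         const c10::optional<torch::Tensor>& bias);

torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                             const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                             const c10::optional<torch::Tensor>& bias) {
  // Allocate the output on the same device as the inputs (fp16 chosen for illustration).
  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)},
                                   mat_a.options().dtype(torch::kHalf));

  // Pick a kernel family from the GPU's compute capability (threshold illustrative).
  auto* props = at::cuda::getCurrentDeviceProperties();
  int sm = props->major * 10 + props->minor;  // e.g. 75, 80, 89
  if (sm >= 80) {
    sm80_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias);
  } else {
    sm75_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
  return out;
}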
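
And a minimal sketch of the launch sequence in step 3, assuming gemm_op and args have already been built; the can_implement and status checks are a typical CUTLASS pattern, not necessarily the exact error handling in the source.

// Run the GEMM on the current CUDA stream of mat_a's device.
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

// Ask CUTLASS whether this kernel configuration can handle the problem.
TORCH_CHECK(gemm_op.can_implement(args) == cutlass::Status::kSuccess,
            "cutlass_int8_scaled_mm: unsupported problem size/configuration");

auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
            "cutlass_int8_scaled_mm failed: ", cutlass::cutlassGetStatusString(status));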

Variable Definitions