We will base our implementation on CUTLASS basic_gemm example 0 (https://github.com/NVIDIA/cutlass/blob/main/examples/00_basic_gemm/basic_gemm.cu). For readers familiar with CUTLASS, note that this example uses the 2.X syntax. In the appendix of this post, we also provide a separate example targeting the NVIDIA Hopper™ architecture that uses the 3.X syntax.
```cpp
template<typename DataType, typename OutputType>
void cutlass_gemm_unpack(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
  // Get data sizes
  const int M = A.sizes()[0];
  const int K = B.sizes()[0];
  const int N = B.sizes()[1];

  // Casting to the data type of the input tensor
  DataType const *ptrA = reinterpret_cast<DataType const*>(A.data_ptr());
  DataType const *ptrB = reinterpret_cast<DataType const*>(B.data_ptr());
  OutputType *ptrC = reinterpret_cast<OutputType*>(C.data_ptr());

  cutlass_gemm_wrapper<DataType, OutputType>(M, N, K, ptrA, ptrB, ptrC);
}
```
Similar to PyTorch's mm, our function will return the C tensor back to PyTorch for use. We also need to update the function parameters to mark C as optional. The Torch C++ API provides the utility c10::optional<torch::Tensor> for declaring a Tensor argument optional. With it, we can check whether an input was provided using the .has_value() method; if that returns true, we can retrieve the value with the .value() method.
```cpp
torch::Tensor cutlass_gemm(torch::Tensor A, torch::Tensor B, c10::optional<torch::Tensor> out) {

  // Handling the optional C matrix
  torch::Tensor C;
  if (out.has_value()) {  // Output tensor was provided. So we will use it.
    C = out.value();
  } else {                // Output tensor was not provided. Creating an empty tensor.
    const int M = A.sizes()[0];
    const int N = B.sizes()[1];

    // We will allocate the matrix on GPU and set the datatype to be the same as the input
    auto c_options = torch::TensorOptions().device(torch::kCUDA).dtype(A.dtype());
    C = torch::empty({M, N}, c_options);
  }
```
```cpp
template<typename DataType, typename OutputType>
void cutlass_gemm_unpack(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
  // Get data sizes
  const int M = A.sizes()[0];
  const int K = B.sizes()[0];
  const int N = B.sizes()[1];

  // Casting to the data type of the input tensor
  DataType const *ptrA = reinterpret_cast<DataType const*>(A.data_ptr());
  DataType const *ptrB = reinterpret_cast<DataType const*>(B.data_ptr());
  OutputType *ptrC = reinterpret_cast<OutputType*>(C.data_ptr());

  cutlass_gemm_wrapper<DataType, OutputType>(M, N, K, ptrA, ptrB, ptrC);
}

// Intermediate function to get the output precision to use for the wrapper template.
template<typename DataType>
void cutlass_gemm_find_output_type(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
  if (C.dtype() == torch::kFloat16)
    cutlass_gemm_unpack<DataType, cutlass::half_t>(A, B, C);
  else if (C.dtype() == torch::kFloat32)
    cutlass_gemm_unpack<DataType, float>(A, B, C);
  else
    throw std::invalid_argument("Unsupported precision type");
}

// This function is bound to "cutlass_gemm.mm". Takes torch::Tensors as inputs.
torch::Tensor cutlass_gemm(torch::Tensor A,                     // A matrix (m x k)
                           torch::Tensor B,                     // B matrix (k x n)
                           c10::optional<torch::Tensor> out) {  // optional out matrix (m x n)

  // Handling the optional C matrix
  torch::Tensor C;
  if (out.has_value()) {  // Output tensor was provided. So we will use it.
    C = out.value();
  } else {                // Output tensor was not provided. Creating an empty tensor.
    const int M = A.sizes()[0];
    const int N = B.sizes()[1];

    // We will allocate the matrix on GPU and set the datatype to be the same as the input
    auto c_options = torch::TensorOptions().device(torch::kCUDA).dtype(A.dtype());
    C = torch::empty({M, N}, c_options);
  }

  // Check that all tensors are allocated on GPU device.
  if (!(A.device().is_cuda() && B.device().is_cuda() && C.device().is_cuda()))
    throw std::invalid_argument("cutlass_gemm only supports GPU device. Use .to(device=torch.device('cuda'))");

  // Ensuring that the matrices are contiguous.
  torch::Tensor _A = A.contiguous();
  torch::Tensor _B = B.contiguous();
  torch::Tensor _C = C.contiguous();

  // Select the CUTLASS precision type to use based on Torch input data type.
  if (A.dtype() == torch::kFloat16)
    cutlass_gemm_find_output_type<cutlass::half_t>(_A, _B, _C);
  else if (A.dtype() == torch::kFloat32)
    cutlass_gemm_find_output_type<float>(_A, _B, _C);
  else
    throw std::invalid_argument("Unsupported precision type");

  // If C was not contiguous, C != _C, so copy the result back into C.
  if (!C.is_contiguous())
    C.copy_(_C);

  // Return the Torch tensor back to PyTorch
  return C;
}
```
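The comment in the listing above says the entry point is bound as cutlass_gemm.mm, but the binding itself does not appear in this snippet. Below is a minimal sketch of what that binding could look like using pybind11 via torch/extension.h; the module name, the "mm" name, and the py::none() default are assumptions inferred from that comment, not code from the original post.

```cpp
#include <torch/extension.h>

// Hypothetical binding: exposes cutlass_gemm() to Python as cutlass_gemm.mm,
// with the output tensor optional on the Python side.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mm", &cutlass_gemm,
        py::arg("A"), py::arg("B"),
        py::arg("out") = py::none());  // maps to c10::optional<torch::Tensor>
}
```

Compiled with torch.utils.cpp_extension (or an equivalent setup.py), this makes the function callable from Python, as in the snippet below.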
```python
import math
import torch

M = K = N = 4096
cuda = torch.device('cuda')

A = torch.normal(0, 1, size=(M, K)).to(device=cuda).to(dtype=torch.float16) / math.sqrt(K)
B = torch.normal(0, 1, size=(K, N)).to(device=cuda).to(dtype=torch.float16) / math.sqrt(K)
```
```cpp
// CUTLASS 2.X syntax GEMM
// Adapted from https://github.com/NVIDIA/cutlass/blob/main/examples/00_basic_gemm/basic_gemm.cu

#include "cutlass/gemm/device/gemm.h"

template<typename DataType, typename OutputType>
void cutlass_gemm_wrapper(int M, int N, int K, DataType const* ptrA, DataType const* ptrB, OutputType* ptrC) {
  using Gemm = cutlass::gemm::device::Gemm<
    DataType,                   // ElementA
    cutlass::layout::RowMajor,  // LayoutA
    DataType,                   // ElementB
    cutlass::layout::RowMajor,  // LayoutB
    OutputType,                 // ElementOutput
    cutlass::layout::RowMajor,  // LayoutOutput
    float                       // ElementAccumulator
  >;

  float alpha = 1.0f;
  float beta = 0.0f;

  // For row-major matrices, the leading dimension is the number of columns.
  int lda = K;
  int ldb = N;
  int ldc = N;

  Gemm gemm_op;
  gemm_op({
    {M, N, K},
    {ptrA, lda},   // TensorRef to A device tensor
    {ptrB, ldb},   // TensorRef to B device tensor
    {ptrC, ldc},   // TensorRef to C device tensor
    {ptrC, ldc},   // TensorRef to D device tensor - may be the same as C
    {alpha, beta}  // epilogue operation arguments
  });
}
```
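Note that the launch above discards the cutlass::Status returned by the GEMM operator. A variant of the same call that surfaces launch failures might look like the sketch below; the error check is our addition, not part of the original example.

```cpp
  // Same launch as above, but keeping and checking the returned status.
  cutlass::Status status = gemm_op({
    {M, N, K},
    {ptrA, lda},
    {ptrB, ldb},
    {ptrC, ldc},
    {ptrC, ldc},
    {alpha, beta}
  });
  if (status != cutlass::Status::kSuccess)
    throw std::runtime_error(cutlassGetStatusString(status));
```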
```cpp
#else

// CUTLASS 3.X syntax GEMM
// Adapted from https://github.com/NVIDIA/cutlass/blob/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"

using namespace cute;

template<typename DataType, typename OutputType>
void cutlass_gemm_wrapper(int M, int N, int K, DataType const* ptrA, DataType const* ptrB, OutputType* ptrC) {

  // A matrix configuration
  using LayoutA = cutlass::layout::RowMajor;                                // Layout type for A matrix operand
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<DataType>::value;   // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutB = cutlass::layout::RowMajor;                                // Layout type for B matrix operand
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<DataType>::value;   // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using LayoutC = cutlass::layout::RowMajor;                                // Layout type for C and D matrix operands
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutputType>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Core kernel configurations
  using ElementAccumulator = float;                                     // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm90;                                  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;                 // Operator class tag
  using TileShape = Shape<_128,_128,_64>;                               // Threadblock-level tile size
  using ClusterShape = Shape<_1,_2,_1>;                                 // Shape of the threadblocks in a cluster
  using StageCountType = cutlass::gemm::collective::StageCountAuto;     // Stage count maximized based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
```
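The listing jumps from these configuration aliases directly to the kernel arguments, omitting the CollectiveBuilder glue that assembles them into a Gemm type, along with the strides and epilogue scalars those arguments reference. The following is a sketch of that missing middle, modeled on example 48; treat the exact builder parameters as assumptions carried over from that example rather than the post's own code.

```cpp
  // Assemble the epilogue and mainloop collectives from the aliases above,
  // then the device-level GEMM type (modeled on example 48).
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      OutputType, LayoutC, AlignmentC,   // source (C) operand
      OutputType, LayoutC, AlignmentC,   // destination (D) operand
      cutlass::epilogue::collective::EpilogueScheduleAuto
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      DataType, LayoutA, AlignmentA,
      DataType, LayoutB, AlignmentB,
      ElementAccumulator,
      TileShape, ClusterShape,
      StageCountType,
      KernelSchedule
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int>,   // ProblemShape (M, N, K)
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Packed row-major strides and the epilogue scalars used by the arguments below
  auto stride_A = cutlass::make_cute_packed_stride(typename Gemm::GemmKernel::StrideA{}, {M, K, 1});
  auto stride_B = cutlass::make_cute_packed_stride(typename Gemm::GemmKernel::StrideB{}, {N, K, 1});
  auto stride_C = cutlass::make_cute_packed_stride(typename Gemm::GemmKernel::StrideC{}, {M, N, 1});
  auto stride_D = cutlass::make_cute_packed_stride(typename Gemm::GemmKernel::StrideD{}, {M, N, 1});
  float alpha = 1.0f, beta = 0.0f;
```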
```cpp
  //
  // Launch GEMM on the device
  //
  typename Gemm::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K},
    {ptrA, stride_A, ptrB, stride_B},
    {{alpha, beta}, ptrC, stride_C, ptrC, stride_D}
  };

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);
```