来源丨https://zhuanlan.zhihu.com/p/670656687
本文是FasterTransformer Decoding 源码分析的第五篇,主要介绍FasterTransformer中融合OP AddBiasResidualLayerNorm是如何实现及优化的。融合OP包含了LayerNorm、AddBias和AddResidual三个算子,其中LayerNorm的实现和分析已经在进击的Killua:FasterTransformer Decoding 源码分析(三)-LayerNorm介绍这篇文章中详解过,剩下两个操作非常简单,本质上就是两个矩阵求和,所以这篇文章内容会很短,主要是介绍融合的这个思想。
背景知识
三个OP分别是LayerNorm、AddResidual和AddBias,下面分别简单介绍下他们的出处。
LayerNorm
LayerNorm(层归一化)是一种用于深度神经网络中的归一化技术。它可以对网络中的每个神经元的输出进行归一化,使得网络中每一层的输出都具有相似的分布,公式为如下所示。
LayerNorm 计算公式
AddResidual
残差设计源自于resnet,是将模块处理前的输入和模块处理后的输出进行相加再作为输出进入后续流程中。引入残差块可以解决深度神经网络训练过程中的梯度消失和梯度爆炸问题,让模型可以更深,收敛更快,可以通过文章7.6. 残差网络(ResNet) - 动手学深度学习 2.0.0 documentation学习。
添加残差
AddBias
添加偏置项来自于线性变换,线性变换的前半部是矩阵乘法,通常会通过cuBLAS中的gemm函数完成,而后半部分和偏置项求和往往会单独调用kernel运算,这里和前两个OP进行了融合,减少了kernel调用。
线性变换
源码分析
入口函数
函数签名如下,这里把三个OP需要的输入和参数都包含进去了。
template< typename T> void invokeGeneralAddBiasResidualPreLayerNorm(T* output, // 添加bias和residual输出 T* norm_output, // 整体正则化输出 const T* input, // 输入 const T* residual1, // 残差1 const T* residual2, // 残差2 const T* gamma, const T* beta, const T* bias, // 偏置 const float layernorm_eps, int m, // block数量 int n, // 处理的每一行数据元素个数 const float * scale_inter, const float * scale_out, float * scale, float * dynamic_scale, const int int8_mode, cudaStream_t stream, int opt_version = 2 );
调用kernel
这里gridSize等于一批处理的数据个数,即一个block处理输入的一行数据,符合并行处理的思路。blockSize是一份数据的维度和1024的较小值,可以理解,大多数CUDA设备一个block支持的最大线程数就是1024,所以这里要min处理下。这里还有个trick就是维度如果不是32的倍数就也设置为1024,主要是为了最大化利用warp(32个线程)特性来处理数据。动态量化的部分我们先跳过,接下来就是调用函数进入到kernel实现部分。
dim3 grid (m); dim3 block (min(n, 1024 )); /* For general cases, n is equal to hidden_units, e.g., 512/1024. Since we have warp shuffle inside the code, block.x % 32 should be 0. */ block.x = (block.x + 31 ) / 32 * 32 ; size_t maxbytes = n * sizeof (T); if (residual_num == 1 ) { if (maxbytes >= (48 << 10 )) { check_cuda_error(cudaFuncSetAttribute( generalAddBiasResidualLayerNorm< T, 1 > , cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes)); } generalAddBiasResidualLayerNorm< T, 1 ><<< grid, block, maxbytes, stream>>> (input, residual1, residual2, gamma, beta, bias, output, norm_output, layernorm_eps, m, n, scale_inter, scale_out, scale, dynamic_scale, int8_mode); }
kernel实现
这里为了代码结构更加清晰先将量化相关的代码先去掉了,整个流程还是比较容易理解,先进行了残差项、输入项和偏置项的求和合并,结果存到共享内存,再通过两次block级别的归约实现了LayerNorm,从而实现了算子融合,详细流程如注释所展示。
block模型上,一个block处理一行数据(n维度),block中有m个线程,1个线程可能处理1到多个数据中的元素,如下图所示,这里n=8,m=4,所以一个线程需要处理2个数据,反映到代码中就是单个线程对2个元素进行本地三项求和、归约求和和差值平方。
template< typename T, int RESIDUAL_NUM> __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, const T* __restrict residual1, const T* __restrict residual2, const T* __restrict gamma, const T* __restrict beta, const T* __restrict bias, T* output, T* norm_output, const float layernorm_eps, int m, int n, const float * scale_inter, const float * scale_out, const float * scale, float * dynamic_scale, const int int8_mode) { int tid = threadIdx.x; // NOTE: float shmem may exceed the shared memory limit // 使用共享内存存储中间变量 extern __shared__ __align__ (sizeof (float )) char _shmem[]; T* shmem = reinterpret_cast< T*> (_shmem); // 定义共享内存变量:均值和方差 __shared__ float s_mean; __shared__ float s_variance; float mean = 0.0f ; float variance = 0.0f ; float local_sum = 0.0f ; // blockDim跨度的遍历,block中线程数小于数据维度数时一个线程可能处理多个element for (int i = tid; i < n; i += blockDim.x) { float local_out = 0.0f ; // 提取残差项 if (RESIDUAL_NUM == 1 ) { local_out = (float )(ldg(& residual1[blockIdx.x * n + i])); } else if (RESIDUAL_NUM == 2 ) { local_out = (float )(ldg(& residual1[blockIdx.x * n + i])) + float (ldg(& residual2[blockIdx.x * n + i])); } // 残差项和输入项合并