SoftMax cuda算子解读¶
参考链接
How to write a fast Softmax CUDA kernel?
softmax公式为
y_i = \frac{e^{x_i-C}}{\sum_i^N e^{x_i-C}}
其中C = \max x, 根据公式,其中有sum与max的操作,所以需要两次归约,第一次归约向量X得到最大值C,第二次归约向量e^{x_i-C}得到和
如果有多个线程块的话,还需要对不同线程块的归约结果进行归约,为了方便将kernel融合为一个,我们将只使用一个block计算softmax
线程块使用32个线程¶
由于线程块只有32个线程,所以可以只用线程束函数进行归约,但在这之前每个线程要先计算cols_per_thread的归约结果, cols_per_thread=N/32
template <typename T>
__inline__ __device__ T Inf();
template <>
__inline__ __device__ float Inf<float>() {
return CUDART_INF_F;
}
//束内归约
//method=0: 求最大值,method=1: 求和
template<int NUM>
inline __device__ void warpReduce(real *val, int method = 0)
{
for(int i = 0; i<NUM; i++)
{
for(int offset = 32>>1; offset>0; offset>>=1)
{
if(method==1)
val[i] = val[i]+__shfl_xor_sync(0xffffffff, val[i], offset);
else
val[i] = max(val[i], __shfl_xor_sync(0xffffffff, val[i], offset));
}
}
}
//一个 Warp 处理一行的计算,适用于 num_cols <= 1024 情况, block = <32, 4>, grid = <M/4>
template<int cols_per_thread>
__global__ void softmax_kernel(real* input, real* output, int M, int N)
{
constexpr int num_packs = (cols_per_thread+3) / 4;
int m_idx = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float4 buf[num_packs];//寄存器存num_packs*4个变量
for(int row = m_idx ; row<M; row+=gridDim.x*blockDim.y)
{
//当前操作的索引
const int row_idx = row*(N);
real* row_ptr = input+row_idx;
real* out_row_ptr = output+row_idx;
real local_max[1] = {-Inf<float>()};
//分块
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
const int col = (pack_id*32+tid)*4;//重排了线程,使得相邻线程访问相邻四个元素,可以换一种思考方式,一个线程处理的元素是cols_per_thread,这是不连续的,每隔32个处理一次
if(col<N)
{
buf[pack_id] = {row_ptr[col], row_ptr[col+1], row_ptr[col+2], row_ptr[col+3]};
local_max[0] = max(local_max[0], max(max(buf[pack_id].x, buf[pack_id].y), max(buf[pack_id].z, buf[pack_id].w)));
}
else {
buf[pack_id].x = -Inf<float>();
buf[pack_id].y = -Inf<float>();
buf[pack_id].z = -Inf<float>();
buf[pack_id].w = -Inf<float>();
}
}
//此时local_max[0]是当前线程中cols_per_thread个值的最大值
warpReduce<1>(local_max);//32个线程束内归约,得到束内最大值,对于此函数也就是全局最大值,且每个线程这个值都一样
float local_sum[1] = {0.0f};//归约和
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
buf[pack_id].x = exp(buf[pack_id].x - local_max[0]);
buf[pack_id].y = exp(buf[pack_id].y - local_max[0]);
buf[pack_id].z = exp(buf[pack_id].z - local_max[0]);
buf[pack_id].w = exp(buf[pack_id].w - local_max[0]);
local_sum[0] += buf[pack_id].x + buf[pack_id].y + buf[pack_id].z + buf[pack_id].w;
}
warpReduce<1>(local_sum, 1);//同上
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
int col = (pack_id*32+tid)*4;
if(col<N)
{
out_row_ptr[col] = buf[pack_id].x/local_sum[0];
out_row_ptr[col+1] = buf[pack_id].y/local_sum[0];
out_row_ptr[col+2] = buf[pack_id].z/local_sum[0];
out_row_ptr[col+3] = buf[pack_id].w/local_sum[0];
}
}
}
}
线程块使用超过32个线程¶
如果线程块中超过x方向超过32个线程,使用束内归约只能得到每个warp的归约结果,还需要对每个warp的归约结果进行进一步归约, 代码如下
//线程块归约, blockDim.x > 32
template<int NUM>
inline __device__ void blockReduce(real *val, int method = 0)
{
__shared__ real s_val[NUM][32];//共享内存存每个warp的归约结果,线程数最大为1024, 所以这里最大为32
int lane = threadIdx.x&0x1f;//threadIdx.x % 32
int wid = threadIdx.x>>5;//threadIdx.x/32
warpReduce<NUM>(val, method);
//此时block中每个warp完成的归约,且在warp首个线程的寄存器中
if(lane==0)//找每个warp的第一个线程
{
for(int i = 0; i<NUM; i++)
s_val[i][wid] = val[i];//把该寄存器的值赋给共享内存
}
__syncthreads();
//将共享内存里的结果赋给寄存器
for(int i = 0; i<NUM; i++)
{
val[i] = (threadIdx.x<(blockDim.x/32.f))? s_val[i][lane]: 0.0;
}
if(wid==0)//此时第一个warp中的数据归约结果
{
warpReduce<NUM>(val,method);
}
}
template<int cols_per_thread, int block_size>
__global__ void softmax_block_kernel(const real* input, real* output, int M, int N)
{
constexpr int num_packs = (cols_per_thread+3) / 4;
int m_idx = blockIdx.x;
const int tid = threadIdx.x;
float4 buf[num_packs];//寄存器存num_packs*4个变量
for(int row = m_idx ; row<M; row+=gridDim.x)
{
//当前操作的索引
const int row_idx = row*(N);
const real* row_ptr = input+row_idx;
real* out_row_ptr = output+row_idx;
real local_max[1] = {-Inf<float>()};
//分块
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
const int col = (pack_id*block_size+tid)*4;//重排了线程,使得相邻线程访问相邻元素,可以换一种思考方式,一个线程处理的元素是cols_per_thread,这是不连续的,每隔32个处理一次
if(col<N)
{
buf[pack_id] = {row_ptr[col], row_ptr[col+1], row_ptr[col+2], row_ptr[col+3]};
local_max[0] = max(local_max[0], max(max(buf[pack_id].x, buf[pack_id].y), max(buf[pack_id].z, buf[pack_id].w)));
}
else {
buf[pack_id].x = -Inf<float>();
buf[pack_id].y = -Inf<float>();
buf[pack_id].z = -Inf<float>();
buf[pack_id].w = -Inf<float>();
}
}
blockReduce<1>(local_max);//此时归约结果在第一个warp中
//将全局最大值赋给共享内存以便全部线程访问
__shared__ float global_max;
if(tid==0)
global_max = local_max[0];
__syncthreads();//让所有线程得到最大值, 也就是对于tid!=0的线程保证共享内存global_max已经被赋值
float local_sum[1] = {0.0f};//归约和
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
buf[pack_id].x = exp(buf[pack_id].x - global_max);
buf[pack_id].y = exp(buf[pack_id].y - global_max);
buf[pack_id].z = exp(buf[pack_id].z - global_max);
buf[pack_id].w = exp(buf[pack_id].w - global_max);
local_sum[0] += buf[pack_id].x + buf[pack_id].y + buf[pack_id].z + buf[pack_id].w;
}
blockReduce<1>(local_sum, 1);//此时归约结果在第一个warp中
__shared__ float global_sum;//同上
if(tid==0)
global_sum = local_sum[0];
__syncthreads();
for(int pack_id = 0; pack_id<num_packs; ++pack_id)
{
int col = (pack_id*block_size+tid)*4;
if(col<N)
{
out_row_ptr[col] = buf[pack_id].x/global_sum;
out_row_ptr[col+1] = buf[pack_id].y/global_sum;
out_row_ptr[col+2] = buf[pack_id].z/global_sum;
out_row_ptr[col+3] = buf[pack_id].w/global_sum;
}
}
}
}
使用共享内存存数据¶
如果寄存器不能保存全部的数据,就需要使用共享内存了, 然后根据线程索引对共享内存进行访问
template<int block_size>
__global__ void softmax_shared_kernel(const real* input, real* output, int M, int N)
{
//blockDim.y = 1
int m_idx = blockIdx.x;
int tid = threadIdx.x;
extern __shared__ float buf[]; //大小为N*sizeof(real)
int num_packs = N>>2;//所有的线程的pack数量
// int num_packs = (N/block_size+3)/4;
for(int row = m_idx; row<M; row+=gridDim.x)
{
const int row_offset = row*N;
const real* row_ptr = input+row_offset;
real* out_row_ptr = output+row_offset;
real local_max[1] = {-Inf<float>()};
for(int pack_id = tid; pack_id<num_packs; pack_id+=block_size)
{
const int col = tid;
// if(col<N/4)
{
float4 pack = {row_ptr[col], row_ptr[col+1], row_ptr[col+2], row_ptr[col+3]};
//这样子存是为了防止bank冲突, 这样线程束中不同线程访问的不同的bank
buf[col] = pack.x;
buf[col+num_packs] = pack.y;
buf[col+2*num_packs] = pack.z;
buf[col+3*num_packs] = pack.w;
local_max[0] = max(local_max[0], max(max(pack.x, pack.y), max(pack.z, pack.w)));
}
}
blockReduce<1>(local_max);//此时规约结果在第一个warp中
__shared__ float global_max;
if(tid==0)
global_max = local_max[0];
__syncthreads();//让所有线程得到最大值, 也就是对于tid!=0的线程保证共享内存global_max已经被赋值
//求和
real local_sum[1] = {0.0f};
for(int pack_id = tid; pack_id<N; pack_id+=block_size)
{
float local_val = exp(buf[pack_id]-global_max);
buf[pack_id] = local_val;
local_sum[0] += local_val;
}
blockReduce<1>(local_sum, 1);//此时归约结果在第一个warp中
__shared__ float global_sum;
if(tid==0)
global_sum = local_sum[0];
__syncthreads();//同上
for(int pack_id=tid; pack_id<num_packs; pack_id+=block_size)
{
int col = pack_id;
out_row_ptr[col] = buf[pack_id]/global_sum;
out_row_ptr[col+1] = buf[pack_id+num_packs]/global_sum;
out_row_ptr[col+2] = buf[pack_id+2*num_packs]/global_sum;
out_row_ptr[col+3] = buf[pack_id+3*num_packs]/global_sum;
}
}
}
总结:
- 如果一个block线程只有32个,所有线程访问同一个值可以通过束内函数去访问寄存器的值,而不用共享内存
- 如果一个block线程超过32个,所有线程访问同一个值就需要使用共享内存,同时要注意同步线程
- 对分块有了更深刻的理解
- 共享内存的赋值可以任意分配,不必按照分块时方法分配,要保证优先避免bank冲突!! 同一个线程束中的不同线程尽量不访问同一个bank