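I am trying to implement a CUDA kernel for the ReduceL2 operation, for educational purposes. For example, I have a 3D tensor of shape 400*500*256
that should be L2-reduced along its first and second axes. By definition, the result is a 1*1*256
tensor. The following NumPy call shows what the operator does: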
reduced = np.sqrt(np.sum(a=np.square(input), axis=tuple((0,1)), keepdims=1))
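To make the shapes concrete, here is a minimal sketch on a small 2*3*4 tensor. The sizes and the names `x`, `reduced_small`, and `per_channel_norm` are chosen only for illustration. It applies the same call and compares it with a per-channel vector norm:

```python
import numpy as np

# Small stand-in for the 400*500*256 tensor; this shape is illustrative only.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4)).astype(np.float32)

# L2 reduction over axes 0 and 1, keeping them as size-1 dimensions.
reduced_small = np.sqrt(np.sum(np.square(x), axis=(0, 1), keepdims=True))

# Equivalent view: flatten the two reduced axes into one and take a
# vector 2-norm for each position along the last axis.
per_channel_norm = np.linalg.norm(x.reshape(-1, x.shape[2]), axis=0)

print(reduced_small.shape)  # shape of the reduced tensor
print(np.allclose(reduced_small.ravel(), per_channel_norm, rtol=1e-5))
```

In other words, each of the last-axis positions gets the Euclidean norm of all values that share that position.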
To implement it as a CUDA kernel, I first implemented it in Python as follows:
reduce_first = np.zeros([500, 256], dtype=np.float32)
reduce_second = np.zeros([256], dtype=np.float32)
res *= res  # res is the input tensor (400, 500, 256); square it in place
for i in range(res.shape[1]):  # 500
    for j in range(res.shape[2]):  # 256
        sum = 0
        for k in range(res.shape[0]):  # 400
            sum += res[k][i][j]
        reduce_first[i][j] = sum
# print(reduce_first.shape)  # (500, 256)
for i in range(reduce_first.shape[1]):  # 256
    sum = 0
    for j in range(reduce_first.shape[0]):  # 500
        sum += reduce_first[j][i]
    reduce_second[i] = np.sqrt(sum)
The reduce_second tensor was verified to be identical to the reduced tensor above. I then translated the Python code directly into the following CUDA kernel:
__global__ void ReduceL2Kernel(float* out, float* in, unsigned int first_dim, unsigned int second_dim, unsigned int third_dim){
    //400 * 500 * 256 -> 1 * 1 * 256
    //idz * idy * idx
    const int idz = blockIdx.z;//0~399
    const int idy = blockIdx.y * blockDim.y + threadIdx.y;//0~499
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;//0~255
    if(idz >= first_dim || idy >= second_dim || idx >= third_dim) return;
    const int index = (idz * second_dim + idy) * third_dim + idx; //400*500*256
    in[index] *= in[index];
    //reduce along '400' axis
    float sum_first = 0.f;
    for(unsigned int i = 0; i < first_dim; i++)
        sum_first += in[(i * second_dim + idy) * third_dim + idx];
    out[idy * third_dim + idx] = sum_first;
    __threadfence();
    //reduce along '500' axis
    float sum_second = 0.f;
    for(unsigned int i = 0; i < second_dim; i++)
        sum_second += out[i * third_dim + idx];
    out[idx] = sqrtf(sum_second);
}
Below is the C++ host-side code that launches the CUDA kernel:
inline int CeilDiv(int a, int b) { return (a + b - 1) / b;}
dim3 block_dims{64, 16, 1};
dim3 grid_dims{CeilDiv(third_dim, 64), CeilDiv(second_dim, block_dims.y), first_dim};
ReduceL2Kernel<<<grid_dims, block_dims, 0, stream>>>((float*)out, (float*)in, first_dim, second_dim, third_dim);
The out tensor differs from the reduced (or reduce_second) tensor in the Python code. I don't know why; any correction would be appreciated.
1 Answer
A simple, correct way to implement the ReduceL2 operator is as follows:
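For context, the kernel in the question has each thread square only its own element of `in`, then read a whole column of `in` and a whole row of `out` that threads in other blocks may not have written yet. `__threadfence()` orders a thread's own memory writes but is not a barrier across the grid, and every `idz` also writes the same `out` locations. One simple way around this, sketched below in Python as a model of the per-thread logic (an illustrative sketch, not necessarily the code originally posted with this answer), is to give each output channel its own thread and let it do the whole sum. The function name `reduce_l2_per_channel` is hypothetical.

```python
import numpy as np

def reduce_l2_per_channel(inp):
    """Model of a kernel that assigns one thread to each index of the last axis."""
    first_dim, second_dim, third_dim = inp.shape
    flat = inp.reshape(-1)  # row-major layout, matching the CUDA index arithmetic
    out = np.zeros(third_dim, dtype=np.float32)
    # In CUDA this outer loop would be the thread index:
    #   idx = blockIdx.x * blockDim.x + threadIdx.x
    for idx in range(third_dim):
        acc = 0.0  # Python float here; a CUDA kernel would typically use a float register
        for i in range(first_dim * second_dim):
            v = float(flat[i * third_dim + idx])
            acc += v * v  # square in a local variable; the input is never modified
        out[idx] = np.sqrt(acc)  # each output element is written once, by its own "thread"
    return out

# Small illustrative check against the NumPy reference from the question.
x = np.random.default_rng(1).standard_normal((4, 5, 6)).astype(np.float32)
reference = np.sqrt(np.sum(np.square(x), axis=(0, 1)))
print(np.allclose(reduce_l2_per_channel(x), reference, rtol=1e-5))
```

Because no thread depends on another thread's writes, this design needs no fence or synchronization, and `out` only needs `third_dim` elements. In a CUDA version, a one-dimensional launch covering `third_dim` threads would be enough, and neighbouring threads would read neighbouring addresses at each step of the inner loop. It is meant as a simple starting point rather than a fast implementation, since only 256 threads would be active for the shape in the question.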