2026/4/6 7:01:41
网站建设
项目流程
背景使用Triton实现一个向量累加triton.jitdefreduction_kernel(input,output,N:int,BLOCK_SIZE:tl.constexpr,num_warps:tl.constexpr,):pidtl.program_id(0)idxtl.arange(0,BLOCK_SIZE)offsetBLOCK_SIZE*pididx maskoffsetN atl.load(inputoffset,maskmask,other0.0)tl.atomic_add(output,tl.sum(a))它的原子加法竟然直接使用不需要进行tid0的判断它是怎么做到一个block只加一次的参考对应的C版本应该是这样的templatesize_t BlockSize__global__voidreduce(constfloat*input,float*output,intN){intidxblockIdx.x*blockDim.xthreadIdx.x;inttidthreadIdx.x;intwarpIdtid/WARP_SIZE;intlaneIdtid%WARP_SIZE;__shared__floats_sum[BlockSize/WARP_SIZE];floatsum(idxN)?0:input[idx];sumwarp_reduce_sum(sum);if(laneId0){s_sum[warpId]sum;}__syncthreads();sums_sum[laneId];sumwarp_reduce_sumBlockSize/WARP_SIZE(sum);if(threadIdx.x0){atomicAdd(output,sum);}}SASS代码分析基本代码如下IMAD.MOV.U32 R1,RZ,RZ,c[0x0][0x28]// R9 threadIdx.xS2R R9,SR_TID.X ULDC.64UR4,c[0x0][0x118]BSSY B0,0x40f5ff1e0MOV R4,RZ// R3 blockIdx.xS2R R3,SR_CTAID.X// R0 R3 0x7f(127)LOP3.LUT R0,R9,0x7f,RZ,0xc0,!PT IMAD R3,R3,0x80,R0 ISETP.GE.AND P0,PT,R3,c[0x0][0x170],PT LEA R2,P1,R3,c[0x0][0x160],0x2LEA.HI.X.SX32 R3,R3,c[0x0][0x164],0x2,P1 P0 BRA0x40f5ff1d0LDG.E R4,[R2.64]BSYNC B0// shfl_xor_syncSHFL.BFLY PT,R3,R4,0x10,0x1f// P1 R9 0x1fLOP3.LUT P1,RZ,R9,0x1f,RZ,0xc0,!PT ISETP.GE.U32.AND P0,PT,R0,0x4,PT FADD R3,R3,R4 SHF.R.U32.HI R4,RZ,0x3,R9 SHFL.BFLY PT,R2,R3,0x8,0x1fLOP3.LUT R4,R4,0xc,RZ,0xc0,!PT FADD R2,R3,R2 SHFL.BFLY PT,R5,R2,0x4,0x1fFADD R5,R2,R5 SHFL.BFLY PT,R6,R5,0x2,0x1fFADD R6,R5,R6 SHFL.BFLY PT,R7,R6,0x1,0x1fFADD R7,R6,R7// P1 false才执行也就是 R9 0x1f 0即laneId 0!P1 STS[R4],R7 BAR.SYNC.DEFER_BLOCKING0x0// P1 P1 R0 0x20ISETP.LT.U32.AND P1,PT,R0,0x20,!P1 !P0 LDS R8,[R0.X4]LOP3.LUT P0,RZ,R9,0x3,RZ,0xc0,!PT ISETP.LT.U32.AND P0,PT,R0,0x4,!P0 SHFL.BFLY PT,R3,R8,0x2,0x1fFADD R3,R8,R3 SHFL.BFLY PT,R2,R3,0x1,0x1fFADD R5,R3,R2 P0 STS[R0.X4],R5 BAR.SYNC.DEFER_BLOCKING0x0// P1 为false 时直接退出!P1 EXIT LDS R5,[RZ]IMAD.MOV.U32 R2,RZ,RZ,c[0x0][0x168]MOV R3,c[0x0][0x16c]MEMBAR.ALL.GPU ERRBAR ATOMG.E.ADD.F32.FTZ.RN.STRONG.GPU PT,R2,[R2.64],R5 CCTL.IVALL从代码中可以看到它执行了7次蝶式交换也就是异或交换前5次是warp内部的累加然后后面2次是4128/324128/324128/32然后ISETP.LT.U32.AND P1,PT,R0,0x20,!P1 !P1 EXIT两个条件一起成立才是P1laneId0tid32不满足这个条件的直接退出在PTX中更加明确setp.eq.b32%p2,%r14,0;%p4 st.shared.b32[%r60],%r7;bar.sync0;$L__tmp16:setp.lt.u32%p7,%r13,32;and.pred%p5,%p2,%p7;ld.shared.b32%r9,[global_smem];mov.u32%r8,0x0;%p5 atom.global.gpu.acq_rel.add.f32%r8,[%rd20],%r9;最后一个原子加法是有谓词的所以triton确实不涉及向量的原子加法确实一个block只会执行一次并且合理推测只要不涉及向量化操作的都只会执行一次