
go - How to optimize this 8-bit positional popcount using assembly?

Reposted. Author: 行者123. Updated: 2023-12-01 19:18:41

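This post is related to "Golang assembly implement of _mm_add_epi32", which adds the elements of two [8]int32 lists pairwise and returns the updated first list.
According to the pprof profile, passing [8]int32 by value is expensive, so I assumed that passing a pointer to the list would be much cheaper, and the benchmark results confirmed this. Here is the Go version: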

func __mm_add_epi32_inplace_purego(x, y *[8]int32) {
	(*x)[0] += (*y)[0]
	(*x)[1] += (*y)[1]
	(*x)[2] += (*y)[2]
	(*x)[3] += (*y)[3]
	(*x)[4] += (*y)[4]
	(*x)[5] += (*y)[5]
	(*x)[6] += (*y)[6]
	(*x)[7] += (*y)[7]
}
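This function is called in a two-level loop.
The algorithm computes the positional population count of an array of bytes.
Thanks to advice from @fuz, I know that writing the whole algorithm in assembly is the best choice and makes sense, but it is beyond my ability because I have never learned assembly programming.
However, it should be easy to optimize the inner loop with assembly: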
counts := make([][8]int32, numRowBytes)

for i, b := range byteSlice {
	if b == 0 { // more than half of the elements in byteSlice are 0
		continue
	}
	expand := _expand_byte[b]
	__mm_add_epi32_inplace_purego(&counts[i], expand)
}

// expands a byte into its bits
var _expand_byte = [256]*[8]int32{
	&[8]int32{0, 0, 0, 0, 0, 0, 0, 0},
	&[8]int32{0, 0, 0, 0, 0, 0, 0, 1},
	&[8]int32{0, 0, 0, 0, 0, 0, 1, 0},
	&[8]int32{0, 0, 0, 0, 0, 0, 1, 1},
	&[8]int32{0, 0, 0, 0, 0, 1, 0, 0},
	...
}
Could you help write an assembly version of __mm_add_epi32_inplace_purego (that would be enough for me), or even of the whole loop? Thank you in advance.
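The lookup table in the question is elided after its first entries. As a rough sketch, such a table could be generated rather than written out by hand. The names expandByte and addInPlace are hypothetical stand-ins for _expand_byte and __mm_add_epi32_inplace_purego, and the table stores values instead of pointers, which is an assumption made for brevity. The entry layout follows the rows shown in the question, where index 7 holds the least significant bit.

```go
package main

import "fmt"

// expandByte mirrors the layout of the question's table: entry b holds the
// bits of b, most significant bit first (index 0 = bit 7, index 7 = bit 0).
// The name and the value (non-pointer) element type are assumptions.
var expandByte [256][8]int32

func init() {
	for b := 0; b < 256; b++ {
		for i := 0; i < 8; i++ {
			expandByte[b][i] = int32(b>>(7-i)) & 1
		}
	}
}

// addInPlace adds y to x element-wise, like the question's pure-Go helper.
func addInPlace(x, y *[8]int32) {
	for i := range x {
		x[i] += y[i]
	}
}

func main() {
	var counts [8]int32
	for _, b := range []byte{0x01, 0x03, 0x80} {
		addInPlace(&counts, &expandByte[b])
	}
	fmt.Println(counts) // [1 0 0 0 0 0 1 2]
}
```

Generating the table in init keeps the source short and avoids typing errors in 256 hand-written rows.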

Best answer

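The operation you want to perform is called a positional population count on bytes. This is a well-known operation used in machine learning, and some research has been done on fast algorithms to solve this problem.
Unfortunately, the implementation of these algorithms is fairly involved. For this reason, I developed a custom algorithm that is much simpler to implement but yields only roughly half the performance of the other method. However, at a measured 10 GB/s, it should still be a decent improvement over what you had previously.
The idea of this algorithm is to gather corresponding bits from groups of 32 bytes using vpmovmskb, then to take a scalar population count, which is then added to the corresponding counter. This keeps the dependency chains short and allows a consistent IPC of 3 to be reached.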
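The idea can be modeled in plain Go to make the data flow visible. This is only an illustrative sketch, not the answer's implementation: moveMask and countBlock are hypothetical names, moveMask imitates what VPMOVMSKB does for a 32-byte vector, and bits.OnesCount32 plays the role of POPCNTL.

```go
package main

import (
	"fmt"
	"math/bits"
)

// moveMask models VPMOVMSKB: bit i of the result is the most significant
// bit of block[i].
func moveMask(block *[32]byte) uint32 {
	var m uint32
	for i, b := range block {
		m |= uint32(b>>7) << uint(i)
	}
	return m
}

// countBlock adds the positional popcount of one 32-byte block to counts,
// with counts[0] counting the least significant bit.
func countBlock(counts *[8]int32, block [32]byte) {
	for k := 7; k >= 0; k-- {
		counts[k] += int32(bits.OnesCount32(moveMask(&block)))
		for i := range block {
			block[i] <<= 1 // expose the next lower bit as the new MSB
		}
	}
}

func main() {
	var block [32]byte
	for i := range block {
		block[i] = 0x81
	}
	block[0] = 0xff
	var counts [8]int32
	countBlock(&counts, block)
	fmt.Println(counts) // [32 1 1 1 1 1 1 32]
}
```

Each of the eight rounds costs one mask extraction, one population count and one addition, and the eight counters do not depend on each other, which is what keeps the dependency chains short.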
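Note that compared to your algorithm, my code flips the order of the bits. You can change this by editing which counts array elements the assembly code accesses, if you want. However, in the interest of future readers, I would like to leave this code with the more common convention in which the least significant bit is considered bit 0.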
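If the caller needs the other bit order, one option that leaves the assembly untouched is to reverse the counter array afterwards. A minimal sketch, with reverseCounts as a hypothetical helper name:

```go
package main

import "fmt"

// reverseCounts converts between the two conventions: counters indexed from
// the least significant bit and counters indexed from the most significant bit.
func reverseCounts(c [8]int32) [8]int32 {
	for i, j := 0, 7; i < j; i, j = i+1, j-1 {
		c[i], c[j] = c[j], c[i]
	}
	return c
}

func main() {
	lsbFirst := [8]int32{5, 0, 0, 0, 0, 0, 0, 1}
	fmt.Println(reverseCounts(lsbFirst)) // [1 0 0 0 0 0 0 5]
}
```

The reversal is done once per result rather than once per byte, so its cost is negligible next to the counting itself.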
Source code
The complete source code can be found on github. The author has since developed this algorithm idea into a portable library that can be used like this:

import "github.com/clausecker/pospop"

var counts [8]int
pospop.Count8(&counts, buf) // add positional popcounts for buf to counts
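The algorithm is provided in two variants and has been tested on a machine whose processor is identified as "Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz".
Positional population count 32 bytes at a time
The counters are kept in general-purpose registers for best performance. Memory is prefetched well in advance for better streaming behaviour. The scalar tail is processed using a very simple SHRL/ADCL combination. A performance of up to 11 GB/s is achieved.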
#include "textflag.h"

// func PospopcntReg(counts *[8]int32, buf []byte)
TEXT ·PospopcntReg(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)

// load counts into register R8--R15
MOVL 4*0(DI), R8
MOVL 4*1(DI), R9
MOVL 4*2(DI), R10
MOVL 4*3(DI), R11
MOVL 4*4(DI), R12
MOVL 4*5(DI), R13
MOVL 4*6(DI), R14
MOVL 4*7(DI), R15

SUBQ $32, CX // pre-subtract 32 bytes from CX
JL scalar

vector: VMOVDQU (SI), Y0 // load 32 bytes from buf
PREFETCHT0 384(SI) // prefetch some data
ADDQ $32, SI // advance SI past them

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R15 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R14 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R13 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R12 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R11 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R10 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R9 // add to counter
VPADDD Y0, Y0, Y0 // shift Y0 left by one place

VPMOVMSKB Y0, AX // move MSB of Y0 bytes to AX
POPCNTL AX, AX // count population of AX
ADDL AX, R8 // add to counter

SUBQ $32, CX
JGE vector // repeat as long as bytes are left

scalar: ADDQ $32, CX // undo last subtraction
JE done // if CX=0, there's nothing left

loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it

SHRL $1, AX // CF=LSB, shift byte to the right
ADCL $0, R8 // add CF to R8

SHRL $1, AX
ADCL $0, R9 // add CF to R9

SHRL $1, AX
ADCL $0, R10 // add CF to R10

SHRL $1, AX
ADCL $0, R11 // add CF to R11

SHRL $1, AX
ADCL $0, R12 // add CF to R12

SHRL $1, AX
ADCL $0, R13 // add CF to R13

SHRL $1, AX
ADCL $0, R14 // add CF to R14

SHRL $1, AX
ADCL $0, R15 // add CF to R15

DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left

// write R8--R15 back to counts
done: MOVL R8, 4*0(DI)
MOVL R9, 4*1(DI)
MOVL R10, 4*2(DI)
MOVL R11, 4*3(DI)
MOVL R12, 4*4(DI)
MOVL R13, 4*5(DI)
MOVL R14, 4*6(DI)
MOVL R15, 4*7(DI)

VZEROUPPER // restore SSE-compatibility
RET
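When trying out a routine like the one above, a pure-Go model of its scalar tail can double as a reference to compare results against, especially for buffer lengths that are not a multiple of 32. The sketch below is an assumption about how one might write such a reference. pospopcntScalar is a hypothetical name, and the shift-and-add in the inner loop stands in for the SHRL/ADCL pair.

```go
package main

import "fmt"

// pospopcntScalar mirrors the scalar tail: each byte is shifted right eight
// times and the bit that falls out is added to the next counter.
func pospopcntScalar(counts *[8]int32, buf []byte) {
	for _, b := range buf {
		for k := 0; k < 8; k++ {
			counts[k] += int32(b & 1) // the bit SHRL would move into CF
			b >>= 1
		}
	}
}

func main() {
	var counts [8]int32
	pospopcntScalar(&counts, []byte{0x01, 0x03, 0x80})
	fmt.Println(counts) // [2 1 0 0 0 0 0 1]
}
```

Like the assembly, this adds to the existing counters instead of resetting them, so repeated calls accumulate.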
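Positional population count 96 bytes at a time with CSA
This variant performs all of the optimisations above, but first reduces 96 bytes to 64 using a single CSA (carry-save adder) step. As expected, this improves performance by roughly 30% and achieves up to 16 GB/s.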
#include "textflag.h"

// func PospopcntRegCSA(counts *[8]int32, buf []byte)
TEXT ·PospopcntRegCSA(SB),NOSPLIT,$0-32
MOVQ counts+0(FP), DI
MOVQ buf_base+8(FP), SI // SI = &buf[0]
MOVQ buf_len+16(FP), CX // CX = len(buf)

// load counts into register R8--R15
MOVL 4*0(DI), R8
MOVL 4*1(DI), R9
MOVL 4*2(DI), R10
MOVL 4*3(DI), R11
MOVL 4*4(DI), R12
MOVL 4*5(DI), R13
MOVL 4*6(DI), R14
MOVL 4*7(DI), R15

SUBQ $96, CX // pre-subtract 96 bytes from CX
JL scalar

vector: VMOVDQU (SI), Y0 // load 96 bytes from buf into Y0--Y2
VMOVDQU 32(SI), Y1
VMOVDQU 64(SI), Y2
ADDQ $96, SI // advance SI past them
PREFETCHT0 320(SI)
PREFETCHT0 384(SI)

VPXOR Y0, Y1, Y3 // first adder: sum
VPAND Y0, Y1, Y0 // first adder: carry out
VPAND Y2, Y3, Y1 // second adder: carry out
VPXOR Y2, Y3, Y2 // second adder: sum (full sum)
VPOR Y0, Y1, Y0 // full adder: carry out

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R15

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R14

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R13

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R12

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R11

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R10

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
VPADDB Y0, Y0, Y0 // shift carry out bytes left
VPADDB Y2, Y2, Y2 // shift sum bytes left
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R9

VPMOVMSKB Y0, AX // MSB of carry out bytes
VPMOVMSKB Y2, DX // MSB of sum bytes
POPCNTL AX, AX // carry bytes population count
POPCNTL DX, DX // sum bytes population count
LEAL (DX)(AX*2), AX // sum popcount plus 2x carry popcount
ADDL AX, R8

SUBQ $96, CX
JGE vector // repeat as long as bytes are left

scalar: ADDQ $96, CX // undo last subtraction
JE done // if CX=0, there's nothing left

loop: MOVBLZX (SI), AX // load a byte from buf
INCQ SI // advance past it

SHRL $1, AX // is bit 0 set?
ADCL $0, R8 // add it to R8

SHRL $1, AX // is bit 0 set?
ADCL $0, R9 // add it to R9

SHRL $1, AX // is bit 0 set?
ADCL $0, R10 // add it to R10

SHRL $1, AX // is bit 0 set?
ADCL $0, R11 // add it to R11

SHRL $1, AX // is bit 0 set?
ADCL $0, R12 // add it to R12

SHRL $1, AX // is bit 0 set?
ADCL $0, R13 // add it to R13

SHRL $1, AX // is bit 0 set?
ADCL $0, R14 // add it to R14

SHRL $1, AX // is bit 0 set?
ADCL $0, R15 // add it to R15

DECQ CX // mark this byte as done
JNE loop // and proceed if any bytes are left

// write R8--R15 back to counts
done: MOVL R8, 4*0(DI)
MOVL R9, 4*1(DI)
MOVL R10, 4*2(DI)
MOVL R11, 4*3(DI)
MOVL R12, 4*4(DI)
MOVL R13, 4*5(DI)
MOVL R14, 4*6(DI)
MOVL R15, 4*7(DI)

VZEROUPPER // restore SSE-compatibility
RET
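Benchmarks
Here are benchmarks for the two algorithms and for a naive reference implementation in pure Go. The full benchmarks can be found in the github repository.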
BenchmarkReference/10-12 12448764 80.9 ns/op 123.67 MB/s
BenchmarkReference/32-12 4357808 258 ns/op 124.25 MB/s
BenchmarkReference/1000-12 151173 7889 ns/op 126.76 MB/s
BenchmarkReference/2000-12 68959 15774 ns/op 126.79 MB/s
BenchmarkReference/4000-12 36481 31619 ns/op 126.51 MB/s
BenchmarkReference/10000-12 14804 78917 ns/op 126.72 MB/s
BenchmarkReference/100000-12 1540 789450 ns/op 126.67 MB/s
BenchmarkReference/10000000-12 14 77782267 ns/op 128.56 MB/s
BenchmarkReference/1000000000-12 1 7781360044 ns/op 128.51 MB/s
BenchmarkReg/10-12 49255107 24.5 ns/op 407.42 MB/s
BenchmarkReg/32-12 186935192 6.40 ns/op 4998.53 MB/s
BenchmarkReg/1000-12 8778610 115 ns/op 8677.33 MB/s
BenchmarkReg/2000-12 5358495 208 ns/op 9635.30 MB/s
BenchmarkReg/4000-12 3385945 357 ns/op 11200.23 MB/s
BenchmarkReg/10000-12 1298670 901 ns/op 11099.24 MB/s
BenchmarkReg/100000-12 115629 8662 ns/op 11544.98 MB/s
BenchmarkReg/10000000-12 1270 916817 ns/op 10907.30 MB/s
BenchmarkReg/1000000000-12 12 93609392 ns/op 10682.69 MB/s
BenchmarkRegCSA/10-12 48337226 23.9 ns/op 417.92 MB/s
BenchmarkRegCSA/32-12 12843939 80.2 ns/op 398.86 MB/s
BenchmarkRegCSA/1000-12 7175629 150 ns/op 6655.70 MB/s
BenchmarkRegCSA/2000-12 3988408 295 ns/op 6776.20 MB/s
BenchmarkRegCSA/4000-12 3016693 382 ns/op 10467.41 MB/s
BenchmarkRegCSA/10000-12 1810195 642 ns/op 15575.65 MB/s
BenchmarkRegCSA/100000-12 191974 6229 ns/op 16053.40 MB/s
BenchmarkRegCSA/10000000-12 1622 698856 ns/op 14309.10 MB/s
BenchmarkRegCSA/1000000000-12 16 68540642 ns/op 14589.88 MB/s
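The MB/s column in output like the above appears when a Go benchmark reports how many bytes each iteration processes. The sketch below shows one way to obtain such a figure outside of go test by using testing.Benchmark. The function reference is a hypothetical naive implementation written for this example, not the one from the repository, and the buffer size of 10000 is an arbitrary choice.

```go
package main

import (
	"fmt"
	"testing"
)

// reference is a naive positional popcount that stands in for the function
// under test. It is an assumption, not the repository's implementation.
func reference(counts *[8]int32, buf []byte) {
	for _, b := range buf {
		for k := 0; k < 8; k++ {
			counts[k] += int32(b>>uint(k)) & 1
		}
	}
}

func main() {
	buf := make([]byte, 10000)
	for i := range buf {
		buf[i] = byte(i * 31)
	}
	res := testing.Benchmark(func(b *testing.B) {
		var counts [8]int32
		b.SetBytes(int64(len(buf))) // enables the MB/s figure
		for i := 0; i < b.N; i++ {
			reference(&counts, buf)
		}
	})
	fmt.Println(res) // iterations, ns/op and MB/s; values depend on the machine
}
```

Swapping reference for one of the assembly routines and varying the buffer length would give figures comparable in form to the table, though the numbers themselves depend on the hardware.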

Regarding "go - How to optimize this 8-bit positional popcount using assembly?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/63248047/
