gpt4 book ai didi

c - 高效计算三个无符号整数的平均值(无溢出)

转载 作者:行者123 更新时间:2023-12-03 07:32:23 24 4
gpt4 key购买 nike

有一个existing question “3 个长整数的平均值”特别关注三个有符号整数的平均值的有效计算。
然而,无符号整数的使用允许额外的优化不适用于上一个问题中涵盖的场景。这个问题是关于三个无符号整数的平均值的有效计算,其中平均值向零舍入,即在数学术语中我想计算 ⌊ (a + b + c)/3 ⌋。
计算此平均值的一种直接方法是

 avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
首先,现代优化编译器会将除法转换为具有倒数加移位的乘法,并将模运算转换为反向乘法和减法,其中反向乘法可能使用许多体系结构上可用的 scale_add 习语,例如 lea在 x86_64 上, addlsl #n在 ARM 上, iscadd在 NVIDIA GPU 上。
在尝试以适用于许多常见平台的通用方式优化上述内容时,我观察到整数运算的成本通常处于逻辑 ≤ (add | sub) ≤ shift ≤ scale_add ≤ mul 的关系中。这里的成本是指所有延迟、吞吐量限制和功耗。当处理的整数类型比 native 寄存器宽度更宽时,任何此类差异都会变得更加明显,例如加工时 uint64_t 32 位处理器上的数据。
因此,我的优化策略是尽量减少指令数量,并在可能的情况下用“廉价”操作替换“昂贵”操作,同时不增加寄存器压力并为宽无序处理器保留可利用的并行性。
第一个观察结果是,我们可以通过首先应用产生和值和进位值的 CSA(进位保存加法器)将三个操作数的和减少为两个操作数的和,其中进位值的权重是和的两倍值(value)。在大多数处理器上,基于软件的 CSA 的成本是五个逻辑。某些处理器,例如 NVIDIA GPU,具有 LOP3可以一举计算三个操作数的任意逻辑表达式的指令,在这种情况下,CSA 压缩为两个 LOP3 s(注意:我还没有说服 CUDA 编译器发出这两个 LOP3 s;它目前产生四个 LOP3 s!)。
第二个观察是因为我们正在计算除以 3 的模数,所以我们不需要反向乘法来计算它。我们可以改用 dividend % 3 = ((dividend / 3) + dividend) & 3 ,将模减少为加法加上逻辑,因为我们已经有了除法结果。这是通用算法的一个实例:股息 % (2n-1) = ((股息/(2n-1) + 股息) & (2n-1)。
最后在修正项中除以 3 (a % 3 + b % 3 + c % 3) / 3我们不需要泛型除以3的代码。由于被除数很小,在[0, 6]中,我们可以简化 x / 3进入 (3 * x) / 8这只需要一个 scale_add 加上一个类次。
下面的代码显示了我目前正在进行的工作。使用编译器资源管理器检查为各种平台生成的代码显示了我期望的紧凑代码(当使用 -O3 编译时)。
然而,在使用 Intel 13.x 编译器对我的 Ivy Bridge x86_64 机器上的代码进行计时时,一个缺陷变得明显:虽然与简单版本相比,我的代码改善了延迟( uint64_t 数据从 18 个周期减少到 15 个周期),但吞吐量恶化(对于 uint64_t 数据,从每 6.8 个周期一个结果到每 8.5 个周期一个结果)。更仔细地查看汇编代码,很明显为什么会这样:我基本上设法将代码从大致三路并行降低到大致两路并行。
是否有一种通用的优化技术,对通用处理器特别是 x86 和 ARM 以及 GPU 的所有版本都有好处,可以保留更多的并行性?或者,是否有一种优化技术可以进一步减少总操作数以弥补减少的并行度?校正项的计算(下面代码中的 tail)似乎是一个很好的目标。简化 (carry_mod_3 + sum_mod_3) / 2看起来很诱人,但为九种可能的组合之一提供了错误的结果。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

#define BENCHMARK (1)
#define SIMPLE_COMPUTATION (0)

#if BENCHMARK
#define T uint64_t
#else // !BENCHMARK
#define T uint8_t
#endif // BENCHMARK

T average_of_3 (T a, T b, T c)
{
T avg;

#if SIMPLE_COMPUTATION
avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
#else // !SIMPLE_COMPUTATION
/* carry save adder */
T a_xor_b = a ^ b;
T sum = a_xor_b ^ c;
T carry = (a_xor_b & c) | (a & b);
/* here 2 * carry + sum = a + b + c */
T sum_div_3 = (sum / 3); // {MUL|MULHI}, SHR
T sum_mod_3 = (sum + sum_div_3) & 3; // ADD, AND

if (sizeof (size_t) == sizeof (T)) { // "native precision" (well, not always)
T two_carry_div_3 = (carry / 3) * 2; // MULHI, ANDN
T two_carry_mod_3 = (2 * carry + two_carry_div_3) & 6; // SCALE_ADD, AND
T head = two_carry_div_3 + sum_div_3; // ADD
T tail = (3 * (two_carry_mod_3 + sum_mod_3)) / 8; // ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
} else {
T carry_div_3 = (carry / 3); // MUL, SHR
T carry_mod_3 = (carry + carry_div_3) & 3; // ADD, AND
T head = (2 * carry_div_3 + sum_div_3); // SCALE_ADD
T tail = (3 * (2 * carry_mod_3 + sum_mod_3)) / 8; // SCALE_ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
}
#endif // SIMPLE_COMPUTATION
return avg;
}

#if !BENCHMARK
/* Test correctness on 8-bit data exhaustively. Should catch most errors */
int main (void)
{
T a, b, c, res, ref;
a = 0;
do {
b = 0;
do {
c = 0;
do {
res = average_of_3 (a, b, c);
ref = ((uint64_t)a + (uint64_t)b + (uint64_t)c) / 3;
if (res != ref) {
printf ("a=%08x b=%08x c=%08x res=%08x ref=%08x\n",
a, b, c, res, ref);
return EXIT_FAILURE;
}
c++;
} while (c);
b++;
} while (b);
a++;
} while (a);
return EXIT_SUCCESS;
}

#else // BENCHMARK

#include <math.h>

// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
LARGE_INTEGER t;
static double oofreq;
static int checkedForHighResTimer;
static BOOL hasHighResTimer;

if (!checkedForHighResTimer) {
hasHighResTimer = QueryPerformanceFrequency (&t);
oofreq = 1.0 / (double)t.QuadPart;
checkedForHighResTimer = 1;
}
if (hasHighResTimer) {
QueryPerformanceCounter (&t);
return (double)t.QuadPart * oofreq;
} else {
return (double)GetTickCount() * 1.0e-3;
}
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
struct timeval tv;
gettimeofday(&tv, NULL);
return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

#define N (3000000)
int main (void)
{
double start, stop, elapsed = INFINITY;
int i, k;
T a, b;
T avg0 = 0xffffffff, avg1 = 0xfffffffe;
T avg2 = 0xfffffffd, avg3 = 0xfffffffc;
T avg4 = 0xfffffffb, avg5 = 0xfffffffa;
T avg6 = 0xfffffff9, avg7 = 0xfffffff8;
T avg8 = 0xfffffff7, avg9 = 0xfffffff6;
T avg10 = 0xfffffff5, avg11 = 0xfffffff4;
T avg12 = 0xfffffff2, avg13 = 0xfffffff2;
T avg14 = 0xfffffff1, avg15 = 0xfffffff0;

a = 0x31415926;
b = 0x27182818;
avg0 = average_of_3 (a, b, avg0);
for (k = 0; k < 5; k++) {
start = second();
for (i = 0; i < N; i++) {
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
b = (b + avg0) ^ a;
a = (a ^ b) + avg0;
}
stop = second();
elapsed = fmin (stop - start, elapsed);
}
printf ("a=%016llx b=%016llx avg=%016llx",
(uint64_t)a, (uint64_t)b, (uint64_t)avg0);
printf ("\rlatency: each average_of_3() took %.6e seconds\n",
elapsed / 16 / N);


a = 0x31415926;
b = 0x27182818;
avg0 = average_of_3 (a, b, avg0);
for (k = 0; k < 5; k++) {
start = second();
for (i = 0; i < N; i++) {
avg0 = average_of_3 (a, b, avg0);
avg1 = average_of_3 (a, b, avg1);
avg2 = average_of_3 (a, b, avg2);
avg3 = average_of_3 (a, b, avg3);
avg4 = average_of_3 (a, b, avg4);
avg5 = average_of_3 (a, b, avg5);
avg6 = average_of_3 (a, b, avg6);
avg7 = average_of_3 (a, b, avg7);
avg8 = average_of_3 (a, b, avg8);
avg9 = average_of_3 (a, b, avg9);
avg10 = average_of_3 (a, b, avg10);
avg11 = average_of_3 (a, b, avg11);
avg12 = average_of_3 (a, b, avg12);
avg13 = average_of_3 (a, b, avg13);
avg14 = average_of_3 (a, b, avg14);
avg15 = average_of_3 (a, b, avg15);
b = (b + avg0) ^ a;
a = (a ^ b) + avg0;
}
stop = second();
elapsed = fmin (stop - start, elapsed);
}
printf ("a=%016llx b=%016llx avg=%016llx", (uint64_t)a, (uint64_t)b,
(uint64_t)(avg0 + avg1 + avg2 + avg3 + avg4 + avg5 + avg6 + avg7 +
avg8 + avg9 +avg10 +avg11 +avg12 +avg13 +avg14 +avg15));
printf ("\rthroughput: each average_of_3() took %.6e seconds\n",
elapsed / 16 / N);

return EXIT_SUCCESS;
}

#endif // BENCHMARK

最佳答案

让我把帽子扔进戒指。在这里不要做任何太棘手的事情,我
思考。

#include <stdint.h>

uint64_t average_of_three(uint64_t a, uint64_t b, uint64_t c) {
uint64_t hi = (a >> 32) + (b >> 32) + (c >> 32);
uint64_t lo = hi + (a & 0xffffffff) + (b & 0xffffffff) + (c & 0xffffffff);
return 0x55555555 * hi + lo / 3;
}
在下面关于不同拆分的讨论之后,这是一个以三个按位与为代价来节省乘法的版本:
T hi = (a >> 2) + (b >> 2) + (c >> 2);
T lo = (a & 3) + (b & 3) + (c & 3);
avg = hi + (hi + lo) / 3;

关于c - 高效计算三个无符号整数的平均值(无溢出),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64563085/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com