- c - 在位数组中找到第一个零
- linux - Unix 显示有关匹配两种模式之一的文件的信息
- 正则表达式替换多个文件
- linux - 隐藏来自 xtrace 的命令
我试图找到潜在阶乘素数的除数(n!+-1 形式的数),因为我最近购买了 Skylake-X 工作站,我认为我可以使用 AVX512 指令提高一些速度。
算法简单,主要步骤是对同一个除数重复取模。主要是遍历大范围的 n 值。这是用 c 编写的朴素方法(P 是素数表):
uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
residue = 2;
for (n=3; n <= nmax; n++){
residue *= n;
residue %= P[i];
// Lets check if we found factor
if (nmin <= n){
if( residue == 1){
report_factor(n, -1, P[i]);
}
if(residue == P[i]- 1){
report_factor(n, 1, P[i]);
}
}
}
}
return EXIT_SUCCESS;
}
这里的想法是检查大范围的 n,例如1,000,000 -> 10,000,000 针对同一组除数。因此,我们将对同一个除数取模数百万次。使用 DIV 非常慢,因此根据计算范围有几种可能的方法。在我的例子中,n 很可能小于 10^7 并且潜在除数 p 小于 10,000 G(< 10^13),所以数字小于 64 位,也小于 53 位!,但是乘积最大余数 (p-1) 乘以 n 大于 64 位。所以我认为蒙哥马利方法的最简单版本不起作用,因为我们从大于 64 位的数字中取模。
我发现了一些用于 power pc 的旧代码,其中在使用 double 时使用 FMA 来获得高达 106 位(我猜)的精确乘积。所以我将这种方法转换为 AVX 512 汇编器(Intel Intrinsics)。这是 FMA 方法的一个简单版本,它基于 Dekker (1971) 的工作,Dekker 产品和 TwoProduct 的 FMA 版本在尝试查找/谷歌搜索背后的基本原理时是有用的词。此论坛也讨论了这种方法(例如 here)。
int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;
for (i = 0; i < APP_BUFLEN; i++){
residue = 2.0;
prime_double = (double)P[i];
prime_double_reciprocal = 1.0 / prime_double;
n_double = 3.0;
for (n=3; n <= nmax; n++){
nr = n_double * residue;
quotient = fma(nr, prime_double_reciprocal, rounding_constant);
quotient -= rounding_constant;
prime_times_quotient_high= prime_double * quotient;
prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;
if (residue < 0.0) residue += prime_double;
n_double += 1.0;
// Lets check if we found factor
if (nmin <= n){
if( residue == 1.0){
report_factor(n, -1, P[i]);
}
if(residue == prime_double - 1.0){
report_factor(n, 1, P[i]);
}
}
}
}
return EXIT_SUCCESS;
}
这里我使用了魔法常量
static const double rounding_constant = 6755399441055744.0;
对于 double ,这是 2^51 + 2^52 魔数(Magic Number)。
我将其转换为 AVX512(每个循环 32 个潜在除数)并使用 IACA 分析结果。它告诉吞吐量瓶颈:后端和后端分配由于分配资源不可用而停滞。我对汇编程序不是很有经验,所以我的问题是我能做些什么来加快速度并解决这个后端瓶颈吗?
AVX512 代码在这里,也可以从 github 找到
uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for a factorial numbers : n! +-1
//nmin is minimum n we want to report and nmax is maximum. P is table of primes
// we process 32 primes in one loop.
// naive version of the algorithm is int he function factorial_naive
// and simple version of the FMA based approach in the function factorial_simpleFMA
const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};
uint64_t n;
__m512d zero, rounding_const, one, n_double;
__m512i prime1, prime2, prime3, prime4;
__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;
__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;
// load data and initialize cariables for loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();
// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);
// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd
// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);
// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);
// residue init
residue1 = _mm512_set1_pd(2.0);
residue2 = _mm512_set1_pd(2.0);
residue3 = _mm512_set1_pd(2.0);
residue4 = _mm512_set1_pd(2.0);
// double counter init
n_double = _mm512_set1_pd(3.0);
// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000
for (n=3; n<=nmax; n++) // main loop
{
// timings for instructions:
// _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
// _mm512_load_pd = vmovapd : L 1, T 0.5
// _mm512_set1_pd
// _mm512_div_pd = vdivpd : L 23, T 16
// _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0,5
// _mm512_mul_pd = vmulpd : L 4, T 0.5
// _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd : L 4, T 0.5
// _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
// _mm512_sub_pd = vsubpd : L 4, T 0.5
// _mm512_cmplt_pd_mask = vcmppd : L ?, Y 1
// _mm512_mask_add_pd = vaddpd : L 4, T 0.5
// _mm512_cmpeq_pd_mask = vcmppd L ?, Y 1
// _mm512_kor = korw L 1, T 1
// nr = residue * n
nr1 = _mm512_mul_pd (residue1, n_double);
nr2 = _mm512_mul_pd (residue2, n_double);
nr3 = _mm512_mul_pd (residue3, n_double);
nr4 = _mm512_mul_pd (residue4, n_double);
// quotient = nr * 1.0/ prime_double + rounding_constant
quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);
// quotient -= rounding_constant, now quotient is rounded to integer
// countient should be at maximum nmax (10,000,000)
quotient1 = _mm512_sub_pd(quotient1, rounding_const);
quotient2 = _mm512_sub_pd(quotient2, rounding_const);
quotient3 = _mm512_sub_pd(quotient3, rounding_const);
quotient4 = _mm512_sub_pd(quotient4, rounding_const);
// now we calculate high and low for prime * quotient using decker product (FMA).
// quotient is calculated using approximation but this is accurate for given quotient
prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);
prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);
// now we calculate new reminder using decker product and using original values
// we subtract above calculated prime * quotient (quotient is aproximation)
residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);
residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);
// lets check if reminder < 0
negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);
// we and prime back to reminder using mask if it was < 0
residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);
n_double = _mm512_add_pd(n_double,one);
// if we are below nmin then we continue next iteration
if (n < nmin) continue;
// Lets check if we found any factors, residue 1 == n!-1
found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);
// residue prime -1 == n!+1
found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);
if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
{ // we find factor very rarely
double *residual_list1 = (double *) &residue1;
double *residual_list2 = (double *) &residue2;
double *residual_list3 = (double *) &residue3;
double *residual_list4 = (double *) &residue4;
double *prime_list1 = (double *) &prime_double1;
double *prime_list2 = (double *) &prime_double2;
double *prime_list3 = (double *) &prime_double3;
double *prime_list4 = (double *) &prime_double4;
for (int i=0; i <8; i++){
if( residual_list1[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
}
if( residual_list2[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
}
if( residual_list3[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
}
if( residual_list4[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
}
if(residual_list1[i] == (prime_list1[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
}
if(residual_list2[i] == (prime_list2[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
}
if(residual_list3[i] == (prime_list3[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
}
if(residual_list4[i] == (prime_list4[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
}
}
}
}
return EXIT_SUCCESS;
}
最佳答案
正如一些评论者所建议的那样:“后端”瓶颈是您对这段代码的期望。这表明您的食物已经吃饱了,这正是您想要的。
看报告,这一段应该有机会:
// Lets check if we found any factors, residue 1 == n!-1
found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);
// residue prime -1 == n!+1
found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);
if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
来自 IACA 的分析:
| 1 | 1.0 | | | | | | | | kmovw r11d, k0
| 1 | 1.0 | | | | | | | | kmovw eax, k1
| 1 | 1.0 | | | | | | | | kmovw ecx, k2
| 1 | 1.0 | | | | | | | | kmovw esi, k3
| 1 | 1.0 | | | | | | | | kmovw edi, k4
| 1 | 1.0 | | | | | | | | kmovw r8d, k5
| 1 | 1.0 | | | | | | | | kmovw r9d, k6
| 1 | 1.0 | | | | | | | | kmovw r10d, k7
| 1 | | 1.0 | | | | | | | or r11d, eax
| 1 | | | | | | | 1.0 | | or r11d, ecx
| 1 | | 1.0 | | | | | | | or r11d, esi
| 1 | | | | | | | 1.0 | | or r11d, edi
| 1 | | 1.0 | | | | | | | or r11d, r8d
| 1 | | | | | | | 1.0 | | or r11d, r9d
| 1* | | | | | | | | | or r11d, r10d
处理器正在将生成的比较掩码 (k0-k7) 移至常规寄存器以进行“或”运算。您应该能够消除这些移动,并且在 6ops 和 8ops 中进行“或”汇总。
注意:found_factor_mask 类型定义为 __mmask8
,它们应该是 __mask16
(512 位 vector 中的 16x 双 float )。这可能会让编译器进行一些优化。如果没有,请按照评论者的说明进行组装。
相关的是:多少次迭代触发了这个 or-mask 子句?正如另一位评论者所观察到的,您应该能够通过累积“或”操作来展开它。在每次展开的迭代结束时(或 N 次迭代后)检查累积的“或”值,如果为“真”,则返回并重新计算值以确定是哪个 n 值触发了它。
(并且,您可以在“roll”中进行二进制搜索以找到匹配的 n 值——这可能会有所收获)。
接下来,您应该能够摆脱这种中间循环检查:
// if we are below nmin then we continue next iteration, we
if (n < nmin) continue;
显示在这里:
| 1* | | | | | | | | | cmp r14, 0x3e8
| 0*F | | | | | | | | | jb 0x229
这可能不是一个巨大的 yield ,因为预测器(可能)会(大部分)正确地得到这个,但是您应该通过为两个“阶段”设置两个不同的循环来获得一些 yield :
即使您获得了一个周期,也只有 3%。由于这通常与上面的大“或”运算有关,因此可能会发现更多的聪明之处。
关于c - 相同除数时快速 AVX512 模,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47855357/
是 if(a == 0 && b == 0 && c == 0) { return; } 一样 if(a == 0) { return; } if(b == 0) { return; } if(c =
我想做这样的事情: Class A Class B extends A Class C extends A B b = new B(); C c = new C(); b->setField("foo
我对 Mysql 世界很天真......:)我试图使用连接从表中查询, 我遇到结果集问题...表结构如下 下面... VIDEO_XXXXX | Field | Type
我最近问过关于从另一个类获取类的唯一实例的问题。 ( How to get specific instance of class from another class in Java? ) 所以,我正
假设我们有两种类型 using t1 = int*; using t2 = int*; 我知道 std::is_same::value会给我们true .什么是,或者是否有模板工具可以实现以下目标?
对于我的一个应用程序,我假设比较 2 个字符串的第一个字符比比较整个字符串是否相等要快。例如,如果我知道只有 2 个可能的字符串(在一组 n 字符串中)可以以相同的字母开头(比如说 'q'),如果是这
我想在我的NXP LPC11U37H主板(ARM Cortex-M0)上分析一些算法,因为我想知道执行特定算法需要多少个时钟周期。 我编写了这些简单的宏来进行一些分析: #define START_C
我在 Excel 中创建了一个宏,它将在 Excel 中复制一个表格,并将行除以我确定的特定数字(默认 = 500 行),并为宏创建的每个部门打开不同的工作表。 使用的代码是这样的: Sub Copy
我想根据第一个字典对第二个字典的值求和。如果我有字典 A 和 B。 A = {"Mark": ["a", "b", "c", "d"], "June": ["e", "a"], "John": ["a
当我这样做时 system()在 Perl 中调用,我通常根据 perldocs 检查返回码.嗯,我是这么想的。大部分时间 $rc!=0对我来说已经足够了。最近我在这里帮助了两个遇到问题的人syste
在我的进度条上,我试图让它检测 div 加载速度。 如果 div 加载速度很快,我想要实现的目标将很快达到 100%。但进度条的加载速度应该与 div 的加载速度一样快。 问题:如何让我的进度条加载
当我获得与本地时间相同的时间戳时,firebase 生成的服务器时间戳是否会自动转换为本地时间,或者我错过了什么? _firestore.collection("9213903123").docume
根据the original OWL definition of OWL DL ,我们不能为类和个体赋予相同的名称(这是 OWL DL 和 OWL Full 之间的明显区别)。 "Punning" i
我有两个输入复选框: 尝试使用 jQuery 来允许两个输入的行为相同。如果选中第一个复选框,则选中第二个复选框。如果未检查第 1 个,则不会检查第 2 个。反之亦然。 我有代码: $('inpu
可以从不同系统编译两个相同的java文件,但它们都有相同的内容操作系统(Windows 7),会生成不同的.class文件(大小)? 最佳答案 是的,您可以检查是否有不同版本的JDK(Java Dev
我正在清理另一个人的正则表达式,他们目前所有的都以结尾 .*$ 那么下面的不是完全一样吗? .* 最佳答案 .*将尽可能匹配,但默认情况下为 .不匹配换行符。如果您要匹配的文本有换行符并且您处于 MU
我使用 Pick ,但是如何编写可以选择多个字段的通用PickMulti呢? interface MyInterface { a: number, b: number, c: number
我有一个 SQL 数据库服务器和 2 个具有相同结构和数据的数据库。我在 2 个数据库中运行相同的 sql 查询,其中一个需要更长的时间,而另一个在不到 50% 的时间内完成。他们都有不同的执行计划。
我需要你的帮助,我有一个包含两列的表,一个 id 和 numpos,我希望 id 和 numops 具有相同的结果。 例子: $cnx = mysql_connect( "localhost", "r
如何将相同的列(在本例中按“级别”排序)放在一起?我正在做一个高分,我从我的数据库中按级别列出它们。如果他们处于同一级别,我希望他们具有相同的 ID。 但是我不想在别人身上显示ID。只有第一个。这是一
我是一名优秀的程序员,十分优秀!