gpt4 book ai didi

c++ - 模运算与 NTT(有限域 DFT)优化

转载 作者:可可西里 更新时间:2023-11-01 16:19:26 31 4
gpt4 key购买 nike

我想使用 NTT 进行快速平方(参见 Fast bignum square computation),但即使对于非常大的数字(超过 12000 位),速度也很慢……

所以我的问题是:

  • 有没有办法优化我的 NTT 变换?
    我指的不是通过并行(多线程)来加速;这里只关心底层实现。
  • 有没有办法加快我的模运算?

  • 这是我用 C++ 编写的(已经优化过的)NTT 源代码(它是完整的,100% 纯 C++ 实现,不需要第三方库,并且应该是线程安全的。注意:源数组会被用作临时缓冲区!!!并且不能原地变换,即目标数组不能与源数组相同)。
    //---------------------------------------------------------------------------
    class fourier_NTT // Number theoretic transform
    {

    public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }

    // main interface
    void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n])
    void INTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n); // init r,L,p,W,iW,rN
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n])

    // Only for testing
    void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n])
    void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n])

    // DWORD arithmetics
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);

    // Modular arithmetics
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
        // Forward transform. A zero n keeps the size and roots from the previous init().
        if (n != 0) init(n);
        NTT_fast(dst, src, N, W);
    }

    //---------------------------------------------------------------------------
    void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
    {
        // Inverse transform: the forward algorithm with the inverse root iW,
        // followed by scaling every element by rN.
        if (n != 0) init(n);
        NTT_fast(dst, src, N, iW);
        DWORD *out = dst;
        for (DWORD left = N; left != 0; --left, ++out)
        {
            *out = modmul(*out, rN);
        }
    }

    //---------------------------------------------------------------------------
    bool fourier_NTT::init(DWORD n)
    {
        // Prepares the constants for transforms of size n. Returns false (and clears
        // all constants) when n is unsupported.
        // (max(src[])^2)*n < p else NTT overflow can occur !!!
        // NTT_fast() is a radix-2 recursion: for n that is not a power of two it reads
        // past the end of its buffers, so such sizes are rejected with the range check.
        if ((n<2)||(n>0x10000000)||(n&(n-1)))
        {
            r=0; L=0; p=0; W=0; iW=0; rN=0; N=0;
            return false;
        }
        r=2; p=0xC0000001; L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
        // Alternative parameter sets (keep the power-of-two check when switching):
        // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
        // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
        // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
        N=n;                 // size of vectors [DWORDs]
        W=modpow(r, L);      // Wn for NTT
        iW=modpow(r,p-1-L);  // Wn for INTT
        rN=modpow(n,p-2);    // scale for INTT (1/n via Fermat; assumes p is prime)
        return true;
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Recursive radix-2 transform of src[0..n-1] into dst[0..n-1] using root w.
        // src is clobbered: the sub-transforms write their results into it.
        if (n < 2)
        {
            if (n == 1) dst[0] = src[0];
            return;
        }
        const DWORD half = n >> 1;
        const DWORD wSquared = modmul(w, w); // root for the half-size transforms

        // Even-indexed samples go to the low half of dst, odd-indexed ones to the high half.
        DWORD k;
        for (k = 0; k < half; k++) dst[k] = src[2 * k];
        for (k = half; k < n; k++) dst[k] = src[2 * (k - half) + 1];

        // Sub-transforms write into src, using dst as scratch.
        NTT_fast(src, dst, half, wSquared);
        NTT_fast(src + half, dst + half, half, wSquared);

        // Butterflies with twiddle factor w^k.
        DWORD twiddle = 1;
        for (k = 0; k < half; k++)
        {
            const DWORD even = src[k];
            const DWORD odd = modmul(src[half + k], twiddle);
            dst[k] = modadd(even, odd);
            dst[half + k] = modsub(even, odd);
            twiddle = modmul(twiddle, w);
        }
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Direct O(n^2) evaluation for testing: dst[j] = sum_i src[i] * (w^j)^i.
        DWORD rootJ = 1; // w^j
        for (DWORD j = 0; j < n; j++)
        {
            DWORD acc = 0;
            DWORD power = 1; // (w^j)^i
            for (DWORD i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = acc;
            rootJ = modmul(rootJ, w);
        }
    }

    //---------------------------------------------------------------------------
    void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Direct O(n^2) inverse transform for testing.
        // NOTE(review): the w argument is ignored; the inverse root iW and the scale rN
        // set by init() are used instead.
        DWORD rootJ = 1; // iW^j
        for (DWORD j = 0; j < n; j++)
        {
            DWORD acc = 0;
            DWORD power = 1; // (iW^j)^i
            for (DWORD i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = modmul(acc, rN);
            rootJ = modmul(rootJ, iW);
        }
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::shl(DWORD a) { const DWORD doubled = a << 1; return doubled & 0xFFFFFFFE; } // shift left by one, result kept to 32 bits
    DWORD fourier_NTT::shr(DWORD a) { const DWORD halved = a >> 1; return halved & 0x7FFFFFFF; } // logical shift right by one, bit 31 cleared

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
    {
        // Generic reduction a mod p by shift-and-subtract (binary long division remainder).
        DWORD divisor = p;
        // Scale the modulus up by powers of two until it exceeds a or reaches bit 31.
        while ((a > divisor) && !(divisor & 0x80000000))
        {
            divisor = shl(divisor);
        }
        // Walk back down, subtracting each scaled modulus that still fits.
        for (;;)
        {
            if (a >= divisor) a -= divisor;
            if (divisor == p) return a;
            divisor = shr(divisor);
        }
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
        // (a + b) mod p without needing a wider integer type.
        const DWORD x = mod(a);
        const DWORD y = mod(b);
        DWORD sum = x + y;
        // floor((x+y)/2) computed without overflow; its bit 31 is the carry out of x+y.
        const DWORD halfSum = shr(x) + shr(y) + shr((x & 1) + (y & 1));
        if (halfSum & 0x80000000) sum -= p; // wrapped sum: subtracting p (mod 2^32) corrects it
        if (sum >= p) sum -= p;
        return sum;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
        // (a - b) mod p.
        const DWORD x = mod(a);
        const DWORD y = mod(b);
        DWORD diff = x - y;
        if (x < y) diff += p; // borrow: add p back
        if (diff >= p) diff -= p;
        return diff;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
        // Shift-and-add multiplication: scan the 32 bits of (a mod p) from the LSB
        // upward, adding the matching doublings of b. b itself is not reduced here;
        // modadd() reduces its operands.
        DWORD multiplier = mod(a);
        DWORD addend = b;
        DWORD product = 0;
        for (int bit = 0; bit < 32; ++bit)
        {
            if (multiplier & 1) product = modadd(product, addend);
            multiplier = shr(multiplier);
            addend = modadd(addend, addend);
        }
        return product;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
        // a^b mod p by left-to-right square-and-multiply over all 32 exponent bits.
        // Neither a nor b is reduced beforehand.
        DWORD result = 1;
        for (int bit = 31; bit >= 0; --bit)
        {
            result = modmul(result, result);
            if ((b >> bit) & 1) result = modmul(result, a);
        }
        return result;
    }
    //---------------------------------------------------------------------------

    我的 NTT 类的用法示例:
    fourier_NTT ntt;
    const DWORD N=32;                                 // transform size (power of two)
    DWORD x[N],y[N],z[N];
    for (DWORD i=0;i<N;i++) { x[i]=i; y[i]=N+i; }    // x={0,1,...,31}, y={32,33,...,63}

    ntt.NTT(z,x,N); // z[N]=NTT(x[N]), also init constants for N (x[] is used as scratch)
    ntt.NTT(x,y);   // x[N]=NTT(y[N]), no recompute of constants, use last N (y[] is used as scratch)
    // pointwise product in the transform domain == cyclic convolution in the original domain
    for (DWORD i=0;i<N;i++) y[i]=ntt.modmul(z[i],x[i]);
    ntt.INTT(x,y);  // x[N]=INTT(y[N]), no recompute of constants, use last N
    // x[] = cyclic convolution of the original x[] and y[]
    // (zero-pad both inputs to 2N if a linear convolution is needed)

    优化前的一些测量(非 NTT 类):
    a = 0.98765588997654321000 | 389*32 bits
    looped 1x times
    sqr1[ 3.177 ms ] fast sqr
    sqr2[ 720.419 ms ] NTT sqr
    mul1[ 5.588 ms ] simpe mul
    mul2[ 3.172 ms ] karatsuba mul
    mul3[ 1053.382 ms ] NTT mul

    我优化后的一些测量(当前代码:更少的递归参数数量/大小,以及更好的模运算):
    a = 0.98765588997654321000 | 389*32 bits
    looped 1x times
    sqr1[ 3.214 ms ] fast sqr
    sqr2[ 208.298 ms ] NTT sqr
    mul1[ 5.564 ms ] simpe mul
    mul2[ 3.113 ms ] karatsuba mul
    mul3[ 302.740 ms ] NTT mul

    对比 NTT mul 和 NTT sqr 的时间(我的优化带来了 3 倍多一点的加速)。这里只循环了 1 次,因此不是很精确(误差约 10%),但加速效果依然很明显(通常我会循环 1000 次甚至更多,但我的 NTT 太慢了)。

    您可以自由使用我的代码……只需在某处保留我的昵称和/或指向此页面的链接(代码注释、readme.txt、about 或其他地方)。希望它有所帮助……(我没有在任何地方看到快速 NTT 的 C++ 源代码,所以只好自己编写)。已对所有被接受的 N 测试过单位根,参见 fourier_NTT::init(DWORD n) 函数。

    P.S.:有关 NTT 的更多信息,请参阅 Translation from Complex-FFT to Finite-Field-FFT .此代码基于我在该链接中的帖子。

    [edit1:] 代码中的进一步更改

    我设法通过利用模素数始终为 0xC0000001 并消除不必要的调用来进一步优化我的模算术。由此产生的加速现在令人惊叹(超过 40 倍),并且在大约 1500 * 32 位阈值之后,NTT 乘法比 karatsuba 更快。顺便说一句,我的 NTT 的速度现在与我在 64 位 double 上优化的 DFFT 的速度相同。

    一些测量:
    a = 0.98765588997654321000 | 1553*32bits
    looped 10x times
    mul2[ 28.585 ms ] karatsuba mul
    mul3[ 26.311 ms ] NTT mul

    模块化算术的新源代码:
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
    {
        // Reduce a 32-bit value into [0,p). A single conditional subtraction suffices
        // because p >= 0x80000000 (p = 0xC0000001 from init()), so a < 2p always.
        // The comparison must be >=: with > the value a == p was returned unreduced.
        if (a>=p) a-=p;
        return a;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
        // (a + b) mod p for p >= 0x80000000.
        DWORD d;
        if (a>=p) a-=p;   // >= so that an operand equal to p becomes 0
        if (b>=p) b-=p;
        d=a+b;
        // d < a exactly when the 32-bit addition carried (same test as the old
        // bit-31-of-half-sum trick). Then a+b = 2^32 + d, and d-=p wraps to
        // d + (2^32 - p), which is the correct residue (possibly still >= p).
        if (d<a) d-=p;
        if (d>=p) d-=p;   // >= so that a+b == p yields 0, not p
        return d;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
        // (a - b) mod p for p >= 0x80000000.
        DWORD d;
        if (a>=p) a-=p;   // >= so that an operand equal to p becomes 0
        if (b>=p) b-=p;
        d=a-b;
        if (a<b) d+=p;    // borrow: add p back (wraps correctly in 32 bits)
        if (d>=p) d-=p;
        return d;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
        // (a * b) mod p.
    #if defined(__BORLANDC__) && !defined(_WIN64)
        // 32-bit Borland inline asm: EDX:EAX = a*b, then EDX = (a*b) % p.
        // The x86 div faults if the quotient does not fit in 32 bits, so the caller must
        // keep a*b < p*2^32 (true whenever a,b < p).
        DWORD _a,_b,_p;
        _a=a;
        _b=b;
        _p=p;
        asm {
        mov eax,_a
        mov ebx,_b
        mul ebx // H(edx),L(eax) = eax * ebx
        mov ebx,_p
        div ebx // eax = H(edx),L(eax) / ebx
        mov _a,edx // edx = H(edx),L(eax) % ebx
        }
        return _a;
    #else
        // Portable path for compilers without Borland-style inline asm (GCC, Clang,
        // 64-bit MSVC). It works for any 32-bit inputs.
        return static_cast<DWORD>(static_cast<unsigned long long>(a) * b % p);
    #endif
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
        // a^b mod p by left-to-right square-and-multiply over all 32 exponent bits.
        // The exponent b is not reduced mod p; the base gets one conditional subtraction.
        if (a>p) a-=p;
        DWORD result=1;
        for (int bit=31; bit>=0; --bit)
        {
            result=modmul(result,result);
            if ((b>>bit)&1) result=modmul(result,a);
        }
        return result;
    }

    //---------------------------------------------------------------------------

    如您所见,函数 shl 和 shr 不再使用。我认为 modpow 还可以进一步优化,但它不是关键函数,因为只被调用很少几次。最关键的函数是 modmul,它似乎已经是最优的了。

    更多问题:
  • 还有其他选项可以加速 NTT 吗?
  • 我对模运算的优化安全吗?(结果似乎相同,但我可能遗漏了某些情况。)

  • [edit2] 新优化
    a = 0.99991970486 | 2000*32 bits
    looped 10x
    sqr1[ 13.908 ms ] fast sqr
    sqr2[ 13.649 ms ] NTT sqr
    mul1[ 19.726 ms ] simpe mul
    mul2[ 31.808 ms ] karatsuba mul
    mul3[ 19.373 ms ] NTT mul

    我从你的所有评论中实现了所有可用的东西(感谢你的洞察力)。

    加速:
  • +2.5%:移除不必要的安全取模操作(Mandalf The Beige)
  • +34.9%:使用预先计算的 W、iW 幂表(Mysticial)
  • +35% 总计

  • 实际完整源代码:
    //---------------------------------------------------------------------------
    //--- Number theoretic transforms: 2.03 -------------------------------------
    //---------------------------------------------------------------------------
    #ifndef _fourier_NTT_h
    #define _fourier_NTT_h
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------
    class fourier_NTT // Number theoretic transform
    {
    public:
    DWORD r,L,p,N;
    DWORD W,iW,rN; // W=(r^L) mod p, iW=inverse W, rN = inverse N
    DWORD *WW,*iWW,NN; // Precomputed (W,iW)^(0,..,NN-1) powers

    // Internals
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
    ~fourier_NTT(){ _free(); }
    void _free(); // Free precomputed W,iW powers tables
    void _alloc(DWORD n); // Allocate and precompute W,iW powers tables

    // Main interface
    void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n])
    void iNTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n); // init r,L,p,W,iW,rN
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n])
    void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);

    // Only for testing
    void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n])
    void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n])

    // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };
    //---------------------------------------------------------------------------

    //---------------------------------------------------------------------------
    void fourier_NTT::_free()
    {
        // Release both power tables and mark the cache empty.
        // delete[] on a null pointer is a no-op, so no guard is needed.
        delete[] WW;
        delete[] iWW;
        WW = NULL;
        iWW = NULL;
        NN = 0;
    }

    //---------------------------------------------------------------------------
    void fourier_NTT::_alloc(DWORD n)
    {
        // Ensure WW[i] = W^i and iWW[i] = iW^i are available for 0 <= i < n.
        // Cached tables can only be extended when they were built from the current
        // W/iW. A later init() with a different size changes W, so stale tables must be
        // discarded; otherwise both smaller and larger sizes would reuse wrong powers.
        // Entry 0 is 1 for every root, so a cache with NN < 2 is always still valid.
        if ((NN>1)&&((WW[1]!=W)||(iWW[1]!=iW))) _free();
        if (n<=NN) return;

        DWORD *newW =new DWORD[n];
        DWORD *newIW=new DWORD[n];
        DWORD i;
        for (i=0;i<NN;i++) { newW[i]=WW[i]; newIW[i]=iWW[i]; } // keep still-valid prefix
        newW[0]=1; newIW[0]=1;
        for (i=NN?NN:1;i<n;i++)
        {
            newW[i] =modmul(newW[i-1], W);
            newIW[i]=modmul(newIW[i-1],iW);
        }
        delete[] WW;
        delete[] iWW;
        WW=newW;
        iWW=newIW;
        NN=n;
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
        // Forward transform using the precomputed powers of W (stride 1 at the top level).
        // A zero n reuses the size and tables from the previous init().
        if (n != 0) init(n);
        NTT_fast(dst, src, N, WW, 1);
    }

    //---------------------------------------------------------------------------
    void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
    {
        // Inverse transform: the same butterflies with the powers of iW, then scale by rN.
        if (n != 0) init(n);
        NTT_fast(dst, src, N, iWW, 1);
        for (DWORD *it = dst, *end = dst + N; it != end; ++it)
        {
            *it = modmul(*it, rN);
        }
    }

    //---------------------------------------------------------------------------
    bool fourier_NTT::init(DWORD n)
    {
        // Prepare roots, scale and power tables for size-n transforms. Returns false
        // (clearing the scalar constants) when n is unsupported.
        // (max(src[])^2)*n < p else NTT overflow can occur!!!
        // NTT_fast() is a radix-2 recursion that reads past its buffers when n is not
        // a power of two, so such sizes are rejected together with the range check.
        if ((n<2)||(n>0x10000000)||(n&(n-1)))
        {
            r=0; L=0; p=0; W=0; iW=0; rN=0; N=0;
            return false;
        }
        r=2; p=0xC0000001; L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
        // Alternative parameter sets (keep the power-of-two check when switching; note the
        // optimized modular arithmetic requires p >= 0x80000000):
        // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
        // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
        // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
        N=n;                // Size of vectors [DWORDs]
        W=modpow(r, L);     // Wn for NTT
        iW=modpow(r,p-1-L); // Wn for INTT
        rN=modpow(n,p-2);   // Scale for INTT (1/n via Fermat; assumes p is prime)
        _alloc(n>>1);       // Precompute W,iW powers
        return true;
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Recursive radix-2 transform with an explicit root w (older variant).
        // src is clobbered: the sub-transforms write their results into it.
        if (n < 2)
        {
            if (n == 1) dst[0] = src[0];
            return;
        }
        const DWORD half = n >> 1;
        const DWORD wSquared = modmul(w, w);

        // Split: evens to the low half, odds to the high half of dst.
        DWORD k;
        for (k = 0; k < half; k++) dst[k] = src[2 * k];
        for (k = half; k < n; k++) dst[k] = src[2 * (k - half) + 1];

        // Recurse with the buffers swapped (results land in src).
        NTT_fast(src, dst, half, wSquared);
        NTT_fast(src + half, dst + half, half, wSquared);

        // Combine with twiddle w^k.
        DWORD twiddle = 1;
        for (k = 0; k < half; k++)
        {
            const DWORD even = src[k];
            const DWORD odd = modmul(src[half + k], twiddle);
            dst[k] = modadd(even, odd);
            dst[half + k] = modsub(even, odd);
            twiddle = modmul(twiddle, w);
        }
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
    {
        // Recursive radix-2 transform using a precomputed power table.
        // w2[k*i2] is the twiddle for butterfly k at this level; each recursion level
        // doubles the stride because the sub-transform root is the square of this one.
        // src is clobbered: the sub-transforms write their results into it.
        if (n < 2)
        {
            if (n == 1) dst[0] = src[0];
            return;
        }
        const DWORD half = n >> 1;

        DWORD k;
        for (k = 0; k < half; k++) dst[k] = src[2 * k];
        for (k = half; k < n; k++) dst[k] = src[2 * (k - half) + 1];

        const DWORD childStride = i2 << 1;
        NTT_fast(src, dst, half, w2, childStride);               // Even
        NTT_fast(src + half, dst + half, half, w2, childStride); // Odd

        for (k = 0; k < half; k++)
        {
            const DWORD even = src[k];
            const DWORD odd = modmul(src[half + k], w2[k * i2]);
            dst[k] = modadd(even, odd);
            dst[half + k] = modsub(even, odd);
        }
    }

    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Direct O(n^2) evaluation for testing: dst[j] = sum_i src[i] * (w^j)^i.
        DWORD rootJ = 1;
        for (DWORD j = 0; j < n; j++)
        {
            DWORD acc = 0;
            DWORD power = 1;
            for (DWORD i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = acc;
            rootJ = modmul(rootJ, w);
        }
    }

    //---------------------------------------------------------------------------
    void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
        // Direct O(n^2) inverse transform for testing.
        // NOTE(review): the w argument is ignored; iW and rN from init() are used.
        DWORD rootJ = 1;
        for (DWORD j = 0; j < n; j++)
        {
            DWORD acc = 0;
            DWORD power = 1;
            for (DWORD i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = modmul(acc, rN);
            rootJ = modmul(rootJ, iW);
        }
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
    {
        // Reduce a 32-bit value into [0,p); valid because p >= 0x80000000 implies a < 2p.
        // Must be >=: with > the value a == p was returned unreduced.
        if (a>=p) a-=p;
        return a;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
        // (a + b) mod p. Operands must already be reduced (a,b < p); the defensive
        // reductions were removed for speed.
        DWORD d=a+b;
        // d < a exactly when the 32-bit add carried (equivalent to the old bit-31 test).
        // Then a+b = 2^32 + d, and d-=p wraps to d + (2^32 - p), the correct residue.
        if (d<a) d-=p;
        if (d>=p) d-=p;   // >= so that a+b == p yields 0, not p
        return d;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
        // (a - b) mod p. Operands must already be reduced (a,b < p).
        DWORD d=a-b;
        if (a<b) d+=p;    // borrow: add p back (wraps correctly in 32 bits)
        if (d>=p) d-=p;   // >= keeps the result strictly below p
        return d;
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
        // (a * b) mod p.
    #if defined(__BORLANDC__) && !defined(_WIN64)
        // 32-bit Borland inline asm: EDX:EAX = a*b, then EDX = (a*b) % p.
        // The x86 div faults if the quotient does not fit in 32 bits, so callers must
        // keep a*b < p*2^32 (true whenever a,b < p).
        DWORD _a,_b,_p;
        _a=a;
        _b=b;
        _p=p;
        asm {
        mov eax,_a
        mov ebx,_b
        mul ebx // H(edx),L(eax) = eax * ebx
        mov ebx,_p
        div ebx // eax = H(edx),L(eax) / ebx
        mov _a,edx // edx = H(edx),L(eax) % ebx
        }
        return _a;
    #else
        // Portable path for compilers without Borland-style inline asm (GCC, Clang,
        // 64-bit MSVC). It works for any 32-bit inputs.
        return static_cast<DWORD>(static_cast<unsigned long long>(a) * b % p);
    #endif
    }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {
        // a^b mod p by left-to-right square-and-multiply over all 32 exponent bits.
        // b is not reduced mod p; the base is used as given (no pre-reduction).
        DWORD result=1;
        for (int bit=31; bit>=0; --bit)
        {
            result=modmul(result,result);
            if ((b>>bit)&1) result=modmul(result,a);
        }
        return result;
    }
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------
    #endif
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------

    将 NTT_fast 拆分为两个函数(一个使用 WW[],另一个使用 iWW[])后,递归调用可以少传一个参数,从而进一步减少栈的使用。但我对此期望不高(只省下一个 32 位指针),而保留单一函数在将来更便于维护代码。许多函数现在处于闲置状态(仅用于测试),例如慢速版本、mod 以及较旧的快速函数(使用 w 参数而不是 *w2,i2 的版本)。

    为避免大数据集溢出,请将输入数值限制在 p/4 位以内,其中 p 是每个 NTT 元素的位数。因此对于这个 32 位版本,输入值最多使用 8 位(32 位 / 4 -> 8 位)。

    [edit3] 简单字符串 bigint测试乘法
    //---------------------------------------------------------------------------
    char* mul_NTT(const char *sx,const char *sy)
    {
    char *s;
    int i,j,k,n;
    // n = min power of 2 <= 2 max length(x,y)
    for (i=0;sx[i];i++); for (n=1;n<i;n<<=1); i--;
    for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--;
    DWORD *x,*y,*xx,*yy,a;
    x=new DWORD[n]; xx=new DWORD[n];
    y=new DWORD[n]; yy=new DWORD[n];

    // Zero padding
    for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
    for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;

    //NTT
    fourier_NTT ntt;
    ntt.NTT(xx,x,n);
    ntt.NTT(yy,y);

    // Convolution
    for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);

    //INTT
    ntt.iNTT(yy,xx);

    //suma
    a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0;
    delete[] x; delete[] xx;
    delete[] y; delete[] yy;

    return s;
    }
    //---------------------------------------------------------------------------

    我原本使用的是 AnsiString,所以把它移植到了 char*,希望移植没有出错。看起来它工作正常(与 AnsiString 版本的结果一致)。
  • sx,sy是十进制整数
  • 返回分配的字符串 (char*)=sx*sy

  • 每个 32 位数据字只存放约 4 位(一个十进制数字),因此没有溢出风险,但速度当然较慢。在我的 bignum 库中使用二进制表示,NTT 的每个 32 位 WORD 存放一个 8 位的数据块。如果 N 很大,溢出风险会更高……

    玩得开心

    最佳答案

    首先,非常感谢您发布并免费使用它。我真的很感激。

    我能够使用一些技巧来消除一些分支,重新排列主循环并修改程序集,并且能够获得 1.35 倍的加速。

    另外,我为 64 位添加了一个预处理器条件,因为 Visual Studio 不允许在 64 位模式下进行内联程序集(谢谢微软;你可以自己动手)。

    优化 modsub() 函数时发生了一件奇怪的事。我像 modadd 那样用位运算技巧重写了它(理论上应该更快),但不知为何,按位运算版本的 modsub 反而更慢。不知道原因,可能只是我电脑的问题。

    //
    // Mandalf The Beige
    // Based on:
    // Spektre
    // http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
    //
    // This code may be freely used however you choose, so long as it is accompanied by this notice.
    //




    #ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
    #define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

    #include <string.h>

    // Fixed-width aliases used by fast_ntt. "unsigned long" is 64 bits on LP64
    // platforms (Linux/macOS x86-64), which would silently break the 32-bit modular
    // arithmetic; "unsigned int" is 32 bits on all mainstream ILP32/LP64/LLP64 targets.
    #ifndef uint32
    #define uint32 unsigned int
    #endif

    #ifndef uint64
    #define uint64 unsigned long long int
    #endif


    class fast_ntt // number theoretic transform
    {
    public:
    fast_ntt()
    {
    r = 0; L = 0;
    W = 0; iW = 0; rN = 0;
    }
    // main interface
    void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n])
    void INTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast INTT(uint32 src[n])
    // helper functions

    private:
    bool init(uint32 n); // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])

    void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
    // only for testing
    void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow NTT(uint32 src[n])
    void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow INTT(uint32 src[n])
    // uint32 arithmetics


    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N;//, p;
    uint32 W, iW, rN;

    const uint32 p = 0xC0000001;
    };

    //---------------------------------------------------------------------------
    void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
    {
        // Forward transform. n == 0 reuses the size and roots of the previous call.
        if (n != 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, W);
    }

    //---------------------------------------------------------------------------
    void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
    {
        // Inverse transform: forward butterflies with the inverse root iW, then scale by rN.
        // n == 0 reuses the size and roots of the previous call.
        if (n != 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, iW);
        for (uint32 *cur = dst; cur != dst + N; ++cur)
        {
            *cur = modmul(*cur, rN);
        }
    }

    //---------------------------------------------------------------------------
    bool fast_ntt::init(uint32 n)
    {
        // Prepare roots and scale for size-n transforms. Returns false (clearing the
        // constants) when n is unsupported.
        // (max(src[])^2)*n < p else NTT overflow can occur !!!
        // NTT_calc() is a radix-2 recursion that reads past its buffers when n is not
        // a power of two, so such sizes are rejected together with the range check.
        if ((n < 2) || (n > 0x10000000) || ((n & (n - 1)) != 0))
        {
            r = 0; L = 0; W = 0;
            iW = 0; rN = 0; N = 0;
            return false;
        }
        r = 2;
        L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit (p = 0xC0000001)
        // Other parameter sets from the original code require a different (non-const) p
        // and the generic modular arithmetic:
        // r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
        // r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
        // r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
        N = n;                     // size of vectors [uint32s]
        W = modpow(r, L);          // Wn for NTT
        iW = modpow(r, p - 1 - L); // Wn for INTT
        rN = modpow(n, p - 2);     // scale for INTT (1/n via Fermat; assumes p is prime)
        return true;
    }

    //---------------------------------------------------------------------------

    void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Entry point for the recursive transform. NTT_calc() needs distinct buffers and
        // clobbers its source, so an in-place call (dst == src) goes through a temporary.
        if (n == 0)
        {
            return;
        }
        if (n == 1)
        {
            dst[0] = src[0];
            return;
        }
        if (dst != src)
        {
            NTT_calc(dst, src, n, w);
            return;
        }
        uint32 *scratch = new uint32[n];
        NTT_calc(scratch, src, n, w);
        memcpy(dst, scratch, n * sizeof(uint32));
        delete[] scratch;
    }

    void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
    {
        // Overload for read-only input: the transform runs on a private copy of src,
        // so the caller's data is left untouched.
        switch (n)
        {
        case 0:
            return;
        case 1:
            dst[0] = src[0];
            return;
        default:
            break;
        }
        uint32 *copy = new uint32[n];
        memcpy(copy, src, n * sizeof(uint32));
        NTT_calc(dst, copy, n, w);
        delete[] copy;
    }



    void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Recursive radix-2 transform of src[0..n-1] into dst[0..n-1] with root w.
        // src is clobbered: it receives the sub-transform results during the recursion.
        if (n <= 1)
        {
            return;
        }

        const uint32 half = n >> 1;
        const uint32 wSquared = modmul(w, w); // root for the half-size transforms

        // Even-indexed samples into the low half of dst, odd-indexed into the high half.
        for (uint32 k = 0; k < half; k++)
        {
            dst[k] = src[2 * k];
        }
        for (uint32 k = half, from = 1; k < n; k++, from += 2)
        {
            dst[k] = src[from];
        }

        if (half == 1)
        {
            // Size-1 transforms are the identity: copy back so src holds the results.
            src[0] = dst[0];
            src[1] = dst[1];
        }
        else
        {
            NTT_calc(src, dst, half, wSquared);               // even
            NTT_calc(src + half, dst + half, half, wSquared); // odd
        }

        // Butterflies. The twiddle for k == 0 is 1, so that multiplication is skipped.
        const uint32 first = src[0];
        const uint32 mid = src[half];
        dst[0] = modadd(first, mid);
        dst[half] = modsub(first, mid);

        uint32 twiddle = 1;
        for (uint32 k = 1; k < half; k++)
        {
            twiddle = modmul(twiddle, w);
            const uint32 even = src[k];
            const uint32 odd = modmul(src[half + k], twiddle);
            dst[k] = modadd(even, odd);
            dst[half + k] = modsub(even, odd);
        }
    }

    //---------------------------------------------------------------------------
    void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Direct O(n^2) evaluation for testing: dst[j] = sum_i src[i] * (w^j)^i.
        uint32 rootJ = 1;
        for (uint32 j = 0; j < n; j++)
        {
            uint32 acc = 0;
            uint32 power = 1;
            for (uint32 i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = acc;
            rootJ = modmul(rootJ, w);
        }
    }

    //---------------------------------------------------------------------------
    void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Direct O(n^2) inverse transform for testing.
        // NOTE(review): the w argument is ignored; iW and rN from init() are used.
        uint32 rootJ = 1;
        for (uint32 j = 0; j < n; j++)
        {
            uint32 acc = 0;
            uint32 power = 1;
            for (uint32 i = 0; i < n; i++)
            {
                acc = modadd(acc, modmul(power, src[i]));
                power = modmul(power, rootJ);
            }
            dst[j] = modmul(acc, rN);
            rootJ = modmul(rootJ, iW);
        }
    }


    //---------------------------------------------------------------------------
    uint32 fast_ntt::modadd(uint32 a, uint32 b)
    {
        // (a + b) mod p for reduced operands (a,b < p).
        uint32 sum = a + b;
        const bool carried = sum < a; // 32-bit overflow: subtracting p (mod 2^32) fixes it
        if (carried)
        {
            sum -= p;
        }
        if (sum >= p)
        {
            sum -= p;
        }
        return sum;
    }

    //---------------------------------------------------------------------------
    uint32 fast_ntt::modsub(uint32 a, uint32 b)
    {
        // (a - b) mod p for reduced operands: a wrapped difference exceeds a,
        // and adding p back brings it into range.
        const uint32 diff = a - b;
        return (diff > a) ? diff + p : diff;
    }

    //---------------------------------------------------------------------------
    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
        // (a * b) mod p.
    #if defined(_M_IX86)
        // 32-bit MSVC inline asm: EDX:EAX = a*b, EDX = (a*b) % p, result left in EAX
        // (MSVC returns EAX when the function has no return statement).
        // The div faults unless a*b < p*2^32, which holds for reduced inputs.
        uint32 _a = a;
        uint32 _b = b;
        uint32 _p = p;
        __asm
        {
            mov eax, _a;
            mul _b;
            div _p;
            mov eax, edx;
        };
    #else
        // Portable path: MSVC has no inline asm on x64/ARM, and GCC/Clang do not
        // accept the __asm { } block syntax. It works for any 32-bit inputs.
        return static_cast<uint32>((static_cast<uint64>(a) * b) % p);
    #endif
    }


    uint32 fast_ntt::modpow(uint32 a, uint32 b)
    {
        // a^b mod p by right-to-left binary exponentiation. The base is not pre-reduced.
        uint32 base = a;
        uint32 result = (b & 1) ? a : 1;
        for (b >>= 1; b != 0; b >>= 1)
        {
            base = modmul(base, base);
            if (b & 1)
            {
                result = modmul(result, base);
            }
        }
        return result;
    }

    新的 modmul:
    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
        // (a * b) mod p for p = 0xC0000001 (3221225473).
    #if defined(_M_IX86)
        // 32-bit MSVC inline asm, division-free: estimate the quotient from the high
        // word of a*b using the constant 2863311530 (0xAAAAAAAA), subtract q*p, then
        // apply conditional corrections. The result is left in EAX, which MSVC returns
        // because the function has no return statement.
        __asm
        {
            mov eax, a;
            mul b;
            mov ebx, eax;
            mov eax, 2863311530;
            mov ecx, edx;
            mul edx;
            shld edx, eax, 1;
            mov eax, 3221225473;

            mul edx;
            sub ebx, eax;
            mov eax, 3221225473;
            sbb ecx, edx;
            jc addback;

            neg ecx;
            and ecx, eax;
            sub ebx, ecx;

            sub ebx, eax;
            sbb edx, edx;
            and eax, edx;
        addback:
            add eax, ebx;
        };
    #else
        // Portable path: MSVC has no inline asm on x64/ARM, and GCC/Clang do not
        // accept the __asm { } block syntax.
        return static_cast<uint32>((static_cast<uint64>(a) * b) % p);
    #endif
    }

    [编辑]
    Spektre,根据您的反馈,我将 modadd 和 modsub 改回了原来的样子。我还意识到我对递归 NTT 函数做了一些不应该的更改。

    [编辑2]
    删除了不需要的 if 语句和按位函数。

    [编辑3]
    添加了新的 modmul 内联程序集。

    关于c++ - 模块化算法和 NTT(有限域 DFT)优化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/18577076/

    31 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com