Created At 2022-07-01 16:19
Updated At 2023-10-14 23:51 题记, 前言, 实现
蒙哥马利模乘 (Montgomery Modular Multiplication)
数学不会欺骗你, 但是计算机会欺骗你.
----Micheal
前言
这句话是调试代码时, 我对一个异常的结果用计算器验证发现和程序跑的结果不同时的感慨.
当时大概是没有用减法进行取模以至于爆数据范围了, 所以计算结果出错.
但是只要经过证明, 数学的定理是不会失效的, 只要计算的每一步正确答案就一定是正确的.
引言
对于大整数的取模乘法, 我们需要对乘积做除法, 如果程序中有大量的取模乘法, 常数必然是极大的, 所以我们考虑如何优化这个常数, 这样就引入了一个算法, 蒙哥马利模乘算法.
由于中文 wiki 词条十分简短, 所以我只能去啃生肉, 进行基本的原理解释.
前情提要: 算法竞赛无需如此卡常, 只是感觉比较有趣, 所以进行理论介绍.
我们尝试用一个范围内的整数, 映射到 [0,Mod) 范围中的每个整数中.
这就需要引入一个数字 R>Mod, 满足 gcd(Mod,R)=1. 设 R 模 Mod 意义下的乘法逆元为 R′. 这时我们知道对于所有 a∈[0,Mod), 模 Mod 意义下 Ra 都是互不相同的.
尝试证明之, 假设有 a,b∈[0,Mod),a=b, 则 Ra≡aR′(modMod), Rb≡bR′(modMod). 如果 Ra≡Rb(modMod), 则:
aR′aR′−bR′(a−b)R′≡bR′≡0≡0(modMod)(modMod)(modMod)
因为 gcd(R,Mod)=1, 因此 R′∈[1,Mod), 所以 a−b=0, 和一开始假设的 a=b 矛盾, 所以假设不成立. 也就是说可以通过模 Mod 意义下 Ra 的值唯一地确定 a.
我们称模 Mod 意义下的 aR 为 a 的蒙哥马利形式, 可以通过 a 的蒙哥马利形式唯一地确定 a.
蒙哥马利形式的乘法
如果我们需要计算 ab, 则可以先考虑计算结果的蒙哥马利形式, 即 abR, 然后再得到 ab.
先提前将所有数字转化为蒙哥马利形式, 也就是 aR, bR, 然后把它们相乘, 得到 abR2, 然后把它当作 abR 的蒙哥马利形式, 即可求出 abR.
REDC 算法
这个算法是帮助我们根据 x∈[0,(R−1)(Mod−1)) 计算模 Mod 意义下 Rx 的算法. 提前处理出 Mod′ 表示模 R 意义下 Mod 的乘法逆元的相反数, 即 Mod′Mod≡−1(modR).
其基本思想是把 x 转化为 x′≡x(modMod), R∣x′, 这样我们就可以直接除以 R 得到结果了. 把 x 拆成 x=qR+r 来考虑. (r∈[0,R))
如果把模 R 意义下的 −Modx 乘以 Mod, 记为 r′=−Modx×Mod, 那么它在模 R 意义下就是 −x 了, 而且因为有因数 Mod 所以在模 Mod 意义下为 0.
这时我们发现 x+r′ 在模 Mod 意义下仍然是 x, 但是在模 R 意义下就变成了 0, 也就是说 x′=x+r′. 直接把 x′ 除以 R 得到答案.
可能看到这里很多人疑惑了, 这样更多的取模和除法不会让常数更大吗. 但是我们可以通过将 2 的整数幂当作 R, 借此可以把取模和除法用位运算代替. 而关于 Mod 的取模和除法就这样被省略掉了.
蒙哥马利形式的转化
为了快速地将整数转化为蒙哥马利形式, 我们需要处理出 R2 模 Mod 意义下的结果, 也就是 R 的蒙哥马利形式, 这样只需要把 x 和 R2 进行蒙哥马利乘法, 就可以得到 xR 模 Mod 意义下的结果了.
实现
我们发现一个问题, 当 Mod<232 时, 计算 Rx 时的 x 的范围是 264 级别的, 因为我们需要在两个数字相乘得到 abR2 的时候对这个数进行操作. 根据 REDC 算法的式子 Rx≡Rx−ModxMod(modMod), 不难发现我们需要将 264 级别的 x 和 −Mod1 相乘再取模, 这样就超过 unsigned long long
范围了, 但是我们关心的其实还是它们对 R 取模的结果, 所以可以先把 x 取模, 减小到 232 以内之后再和 −Mod1 相乘.
另外因为我们求 Rx 的时候需要两个 Mod×R 级别的数字的加法, 所以在除以 R 之后, 需要判断是否大于 Mod 并使用减法进行取模.
下面是用蒙哥马利模乘算法实现的 NTT 求多项式乘法的代码, 但是反向优化了, 可能是有更优秀的高明写法.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| const unsigned long long Mod(998244353), One21(2097151), Cover(1073741823), InvRevMod(998244351), R_Sq(682155965), R1(75497471), R3(226492413); inline void Mn(unsigned &x) { x -= (x >= Mod) ? Mod : 0; } inline void Mn(unsigned long long &x) { x -= (x >= Mod) ? Mod : 0; } inline unsigned Mned(unsigned x) { return x - ((x >= Mod) ? Mod : 0); } inline unsigned AR_to_A(unsigned long long x) { return Mned((x + (((x & Cover) * InvRevMod) & Cover) * Mod) >> 30); } inline unsigned Mult(unsigned long long x, unsigned long long y) { return AR_to_A(x * y); } inline unsigned A_to_AR(unsigned long long x) { return AR_to_A(x * R_Sq); } inline unsigned long long Pow(unsigned long long x, unsigned y) { unsigned long long Rt(R1); while (y) { if (y & 1) Rt = Mult(Rt, x); y >>= 1, x = Mult(x, x); } return Rt; } unsigned W[2097152], IW[2097152]; unsigned a[2097152], b[2097152]; unsigned n, m, l; unsigned A, B, C, D, t; unsigned Cnt(0), Ans(0), Tmp(0); void Init() { unsigned long long w(Pow(R3, (Mod - 1) >> 21)); W[0] = IW[0] = R1; for (unsigned i(1); !(i >> 21); ++i) W[i] = Mult(W[i - 1], w); w = Pow(w, Mod - 2); for (unsigned i(1); !(i >> 21); ++i) IW[i] = Mult(IW[i - 1], w); } void DIT(unsigned *F, unsigned N) { for (unsigned i(1), I(1 << 20); !(i >> N); i <<= 1, I >>= 1) { for (unsigned j(0), J(0); !(j >> N); ++j, J = ((J + I) & One21)) if (!(j & i)) { unsigned long long TmA(F[j]), TmB(Mult(F[j ^ i], W[J])); Mn(F[j] = TmA + TmB); Mn(F[j ^ i] = Mod + TmA - TmB); } } } void DIF(unsigned *F, unsigned N) { for (unsigned i(1 << (N - 1)), I(1 << (21 - N)); i; i >>= 1, I <<= 1) { for (unsigned j(0), J(0); !(j >> N); ++j, J = ((J + I) & One21)) if (!(j & i)) { unsigned long long TmA(F[j]), TmB(F[j ^ i]); Mn(F[j] = TmA + TmB), Mn(TmB = Mod + TmA - TmB); F[j ^ i] = Mult(TmB, IW[J]); } } } unsigned Tms(unsigned *F, unsigned *G, unsigned lFG) { unsigned N(0); while ((lFG - 1) >> N) ++N; DIF(F, N), DIF(G, N); for (unsigned i((1 << N) - 1); ~i; --i) F[i] = Mult(F[i], G[i]); DIT(F, N); unsigned long long IN(Pow(A_to_AR(1 << N), Mod - 2)); for (unsigned i((1 << N) - 1); ~i; --i) F[i] = Mult(F[i], IN); return lFG; } signed main() { Init(); n = RD() + 1, m = RD() + 1; for (unsigned i(0); i < n; ++i) a[i] = A_to_AR(RD()); for (unsigned i(0); i < m; ++i) b[i] = A_to_AR(RD()); l = Tms(a, b, n + m - 1); for (unsigned i(0); i < l; ++i) printf("%u ", AR_to_A(a[i])); putchar(0x0A); return Wild_Donkey; }
|
参考
Montgomery modular multiplication, Wikipedia
所以 Wikipedia 快出中文词条吧.