蒙哥马利模乘算法

Created At 2022-07-01 16:19
Updated At 2023-10-14 23:51 题记, 前言, 实现

蒙哥马利模乘 (Montgomery Modular Multiplication)

数学不会欺骗你, 但是计算机会欺骗你.
----Micheal

前言

这句话是调试代码时, 我对一个异常的结果用计算器验证发现和程序跑的结果不同时的感慨.

当时大概是没有用减法进行取模以至于爆数据范围了, 所以计算结果出错.

但是只要经过证明, 数学的定理是不会失效的, 只要计算的每一步正确答案就一定是正确的.

引言

对于大整数的取模乘法, 我们需要对乘积做除法, 如果程序中有大量的取模乘法, 常数必然是极大的, 所以我们考虑如何优化这个常数, 这样就引入了一个算法, 蒙哥马利模乘算法.

由于中文 wiki 词条十分简短, 所以我只能去啃生肉, 进行基本的原理解释.

前情提要: 算法竞赛无需如此卡常, 只是感觉比较有趣, 所以进行理论介绍.

蒙哥马利形式 (Montgomery Form)

我们尝试用一个范围内的整数, 映射到 [0,Mod)[0, Mod) 范围中的每个整数中.

这就需要引入一个数字 R>ModR > Mod, 满足 gcd(Mod,R)=1\gcd(Mod, R) = 1. 设 RRModMod 意义下的乘法逆元为 RR'. 这时我们知道对于所有 a[0,Mod)a \in [0, Mod), 模 ModMod 意义下 aR\frac aR 都是互不相同的.

尝试证明之, 假设有 a,b[0,Mod),aba, b \in [0, Mod), a \neq b, 则 aRaR(modMod)\frac aR \equiv aR' \pmod {Mod}, bRbR(modMod)\frac bR \equiv bR' \pmod {Mod}. 如果 aRbR(modMod)\frac aR \equiv \frac bR \pmod {Mod}, 则:

aRbR(modMod)aRbR0(modMod)(ab)R0(modMod)\begin{aligned} aR' &\equiv bR' &\pmod {Mod}\\ aR' - bR' &\equiv 0 &\pmod {Mod}\\ (a - b)R' &\equiv 0 &\pmod {Mod}\\ \end{aligned}

因为 gcd(R,Mod)=1\gcd(R, Mod) = 1, 因此 R[1,Mod)R' \in [1, Mod), 所以 ab=0a - b = 0, 和一开始假设的 aba \neq b 矛盾, 所以假设不成立. 也就是说可以通过模 ModMod 意义下 aR\frac aR 的值唯一地确定 aa.

我们称模 ModMod 意义下的 aRaRaa 的蒙哥马利形式, 可以通过 aa 的蒙哥马利形式唯一地确定 aa.

蒙哥马利形式的乘法

如果我们需要计算 abab, 则可以先考虑计算结果的蒙哥马利形式, 即 abRabR, 然后再得到 abab.

先提前将所有数字转化为蒙哥马利形式, 也就是 aRaR, bRbR, 然后把它们相乘, 得到 abR2abR^2, 然后把它当作 abRabR 的蒙哥马利形式, 即可求出 abRabR.

REDC 算法

这个算法是帮助我们根据 x[0,(R1)(Mod1))x \in [0, (R - 1)(Mod - 1)) 计算模 ModMod 意义下 xR\frac xR 的算法. 提前处理出 ModMod' 表示模 RR 意义下 ModMod 的乘法逆元的相反数, 即 ModMod1(modR)Mod'Mod \equiv -1 \pmod R.

其基本思想是把 xx 转化为 xx(modMod)x' \equiv x \pmod {Mod}, RxR|x', 这样我们就可以直接除以 RR 得到结果了. 把 xx 拆成 x=qR+rx = qR + r 来考虑. (r[0,R))(r \in [0, R))

如果把模 RR 意义下的 xMod- \frac x{Mod} 乘以 ModMod, 记为 r=xMod×Modr' = - \frac x{Mod} \times Mod, 那么它在模 RR 意义下就是 x-x 了, 而且因为有因数 ModMod 所以在模 ModMod 意义下为 00.

这时我们发现 x+rx + r' 在模 ModMod 意义下仍然是 xx, 但是在模 RR 意义下就变成了 00, 也就是说 x=x+rx' = x + r'. 直接把 xx' 除以 RR 得到答案.

可能看到这里很多人疑惑了, 这样更多的取模和除法不会让常数更大吗. 但是我们可以通过将 22 的整数幂当作 RR, 借此可以把取模和除法用位运算代替. 而关于 ModMod 的取模和除法就这样被省略掉了.

蒙哥马利形式的转化

为了快速地将整数转化为蒙哥马利形式, 我们需要处理出 R2R^2ModMod 意义下的结果, 也就是 RR 的蒙哥马利形式, 这样只需要把 xxR2R^2 进行蒙哥马利乘法, 就可以得到 xRxRModMod 意义下的结果了.

实现

我们发现一个问题, 当 Mod<232Mod < 2^{32} 时, 计算 xR\frac xR 时的 xx 的范围是 2642^{64} 级别的, 因为我们需要在两个数字相乘得到 abR2abR^2 的时候对这个数进行操作. 根据 REDC 算法的式子 xRxxModModR(modMod)\frac xR \equiv \dfrac{x - \frac x{Mod}Mod}R \pmod {Mod}, 不难发现我们需要将 2642^{64} 级别的 xx1Mod-\frac 1{Mod} 相乘再取模, 这样就超过 unsigned long long 范围了, 但是我们关心的其实还是它们对 RR 取模的结果, 所以可以先把 xx 取模, 减小到 2322^{32} 以内之后再和 1Mod-\frac 1{Mod} 相乘.

另外因为我们求 xR\frac xR 的时候需要两个 Mod×RMod \times R 级别的数字的加法, 所以在除以 RR 之后, 需要判断是否大于 ModMod 并使用减法进行取模.

下面是用蒙哥马利模乘算法实现的 NTT 求多项式乘法的代码, 但是反向优化了, 可能是有更优秀的高明写法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
const unsigned long long Mod(998244353), One21(2097151), Cover(1073741823),
InvRevMod(998244351), R_Sq(682155965), R1(75497471), R3(226492413);
inline void Mn(unsigned &x) { x -= (x >= Mod) ? Mod : 0; }
inline void Mn(unsigned long long &x) { x -= (x >= Mod) ? Mod : 0; }
inline unsigned Mned(unsigned x) { return x - ((x >= Mod) ? Mod : 0); }
inline unsigned AR_to_A(unsigned long long x) { // Less than Mod
return Mned((x + (((x & Cover) * InvRevMod) & Cover) * Mod) >> 30);
}
inline unsigned Mult(unsigned long long x, unsigned long long y) {
return AR_to_A(x * y);
}
inline unsigned A_to_AR(unsigned long long x) { return AR_to_A(x * R_Sq); }
inline unsigned long long Pow(unsigned long long x, unsigned y) {
unsigned long long Rt(R1);
while (y) {
if (y & 1) Rt = Mult(Rt, x);
y >>= 1, x = Mult(x, x);
}
return Rt;
}
unsigned W[2097152], IW[2097152];
unsigned a[2097152], b[2097152];
unsigned n, m, l;
unsigned A, B, C, D, t;
unsigned Cnt(0), Ans(0), Tmp(0);
void Init() {
unsigned long long w(Pow(R3, (Mod - 1) >> 21));
W[0] = IW[0] = R1;
for (unsigned i(1); !(i >> 21); ++i) W[i] = Mult(W[i - 1], w);
w = Pow(w, Mod - 2);
for (unsigned i(1); !(i >> 21); ++i) IW[i] = Mult(IW[i - 1], w);
}
void DIT(unsigned *F, unsigned N) { // Len = 2^N
for (unsigned i(1), I(1 << 20); !(i >> N); i <<= 1, I >>= 1) {
for (unsigned j(0), J(0); !(j >> N); ++j, J = ((J + I) & One21))
if (!(j & i)) {
unsigned long long TmA(F[j]), TmB(Mult(F[j ^ i], W[J]));
Mn(F[j] = TmA + TmB);
Mn(F[j ^ i] = Mod + TmA - TmB);
}
}
}
void DIF(unsigned *F, unsigned N) { // Len = 2^N
for (unsigned i(1 << (N - 1)), I(1 << (21 - N)); i; i >>= 1, I <<= 1) {
for (unsigned j(0), J(0); !(j >> N); ++j, J = ((J + I) & One21))
if (!(j & i)) {
unsigned long long TmA(F[j]), TmB(F[j ^ i]);
Mn(F[j] = TmA + TmB), Mn(TmB = Mod + TmA - TmB);
F[j ^ i] = Mult(TmB, IW[J]);
}
}
}
unsigned Tms(unsigned *F, unsigned *G, unsigned lFG) {
unsigned N(0);
while ((lFG - 1) >> N) ++N;
DIF(F, N), DIF(G, N);
for (unsigned i((1 << N) - 1); ~i; --i) F[i] = Mult(F[i], G[i]);
DIT(F, N);
unsigned long long IN(Pow(A_to_AR(1 << N), Mod - 2));
for (unsigned i((1 << N) - 1); ~i; --i) F[i] = Mult(F[i], IN);
return lFG;
}
signed main() {
Init();
n = RD() + 1, m = RD() + 1;
for (unsigned i(0); i < n; ++i) a[i] = A_to_AR(RD());
for (unsigned i(0); i < m; ++i) b[i] = A_to_AR(RD());
l = Tms(a, b, n + m - 1);
for (unsigned i(0); i < l; ++i) printf("%u ", AR_to_A(a[i]));
putchar(0x0A);
return Wild_Donkey;
}

参考

Montgomery modular multiplication, Wikipedia

所以 Wikipedia 快出中文词条吧.