Rust 如何解决可能的乘法溢出以获得正确的模运算?

Rust 如何解决可能的乘法溢出以获得正确的模运算?,rust,modulo,integer-overflow,int128,Rust,Modulo,Integer Overflow,Int128,我必须执行(a*b)%m,但是a、b和m是128位无符号类型,乘法过程中溢出的可能性很大。我怎样才能得到正确的答案(可能使用%更多) 我试图在Rust中实现模指数函数,其中最大的内置类型是u128(这是我可以使用的最大值)。这三个变量都非常大,因此(a*b)>2^128非常简单。我可以使用a.overflowing\u mul(b)来检测是否发生了溢出,但我不知道如何从溢出结果(可以认为是(a*b)%2^128)返回以获取(a*b)%m 我的模块化指数代码如下所示(目前未添加溢出支持): 从数学

我必须执行
(a*b)%m
,但是
a
b
m
是128位无符号类型,乘法过程中溢出的可能性很大。我怎样才能得到正确的答案(可能使用
%
更多)

我试图在Rust中实现模指数函数,其中最大的内置类型是
u128
(这是我可以使用的最大值)。这三个变量都非常大,因此
(a*b)>2^128
非常简单。我可以使用
a.overflowing\u mul(b)
来检测是否发生了溢出,但我不知道如何从溢出结果(可以认为是
(a*b)%2^128
)返回以获取
(a*b)%m

我的模块化指数代码如下所示(目前未添加溢出支持):

从数学角度来看:

(a*b)%m实际上是(a*b)%b%m
|B=当前基数(2^128)
示例:

//数学
(9 * 13) % 11 = 7
//实数(基数20):
(9*13)%(B=20)%11=6
^^^^^^^^^^^应该是7
(8 * 4) % 14 = 4
(8*4)%(B=16)%14=0
^^^^^^^^^^^应该是4

此实现基于将128位产品拆分为四个64位产品,速度是以下产品的五倍、十倍和2.3倍:

fn mul_mod(a:u128,b:u128,m:u128)->u128{

如果m 64,x&!(!0),我不想马上说“这是不可能的”,但是仅仅使用
num
板条箱中的大整数不是更安全、更容易吗?你有什么好的理由来完成这个练习吗?顺便问一下:(0..e).折叠(…*)对于
e:u128
,您不想做什么,因为CPU的时钟大约需要22895156670个宇宙的生命周期来计时2^128次。请参阅:(这是简单的部分,对乘法溢出没有帮助)。我已经放弃使用
u128
s,现在我只使用
num
板条箱。感谢@AndreyTyukin提供的关于平方和乘法的提示。而不是循环使用
c=add(c,c)
,你能不能使用
checked\u shl
并使用相同的
unwrap\u或_else
?这会提高速度吗?@ARaspiK我不确定你的建议,但听起来不太可能。当前的实现谨慎地维护不变的
c
,而不在内部循环中进行任何模运算(128位模本身有一个超过128位的循环)。
选中\u shl
不会这样做。@ARaspiK我刚刚发现一个单独的~15%的改进。
fn mod_exp(b: u128, e: u128, m: u128) {
    (0..e).fold(1, |x, _| (x * b) % m)
    //                    ^^^^^^^^^^^
}
fn mul_mod(a: u128, b: u128, m: u128) -> u128 {
    if m <= 1 << 64 {
        ((a % m) * (b % m)) % m
    } else {
        let add = |x: u128, y: u128| x.checked_sub(m - y).unwrap_or_else(|| x + y);
        let split = |x: u128| (x >> 64, x & !(!0 << 64));
        let (a_hi, a_lo) = split(a);
        let (b_hi, b_lo) = split(b);
        let mut c = a_hi * b_hi % m;
        let (d_hi, d_lo) = split(a_lo * b_hi);
        c = add(c, d_hi);
        let (e_hi, e_lo) = split(a_hi * b_lo);
        c = add(c, e_hi);
        for _ in 0..64 {
            c = add(c, c);
        }
        c = add(c, d_lo);
        c = add(c, e_lo);
        let (f_hi, f_lo) = split(a_lo * b_lo);
        c = add(c, f_hi);
        for _ in 0..64 {
            c = add(c, c);
        }
        add(c, f_lo)
    }
}