mod の下での指数計算をするアルゴリズムを python で書いてみる

整数  a と整数 n に対して, a^n を整数p で割った余りを出力します. 安直に an 回乗じてからpで割れば良いのではないかと考えたくなりますが, この方法には2つの欠点があります.

  1. n 回乗じるのに計算量が O(n) かかってしまう.
  2. a^n は処理しきれなくなるくらい大きくなりえる.

この2点を解消する方法について解説します.
まず, 計算量を O(\log{n}) に減らす工夫をしましょう. 具体的な例として, a=3,n=8 の場合, つまり 3^8 を計算する方法を反省してみます. 8=2^3 なので,
\begin{equation}
\begin{split}3^8&=3^{2^3}\\&=3^{2\cdot 2^2}\\&=9^{2^2}\\&=9^{2\cdot 2}\\&=81^2\\&=6561
\end{split}
\end{equation}
となります. 2 乗した数を 2 乗すると 4 乗に, 4 乗した数を 2 乗すると 8 乗に・・・となっていくことから, n2 の累乗で表される場合は非常に少ない計算回数で実行することが出来そうです.

今度は 3^{13} を計算する方法を反省してみます. 13=2^3+2^2+2^0 なので,
\begin{equation}
\begin{split}3^{13}&=3^{2^3+2^2+2^0}\\&=3^{2^3}\times 3^{2^2}\times 3^{2^0}
\\&= 9^{2^2}\times 3^{2^2}\times 3^{2^0}
\\&= (9\times3)^{2^2}\times 3^{2^0}
\\&= 27^{2^2}\times 3^{2^0}
\\&= 729^{2^1}\times 3^{2^0}
\\&= 531441^{2^0}\times 3^{2^0}
\\&= (531441\times 3)^{2^0}
\\&= 1594323^{2^0}
\\&= 1594323
\end{split}
\end{equation}
と計算できます. 計算のやり方は, 132 進表示して, 3^{2^3}\times 3^{2^2}\times 3^{2^0} を手前のものから求めていく際に, べきをおろす回数が等しくなったものは一緒に計算してしまえることを用いています. 以上の方法によって, 計算回数は高々 O(\log{n}) になりました.


次に a^n が大きくなりえる点を合同式の考え方を用いて解消していきます. 整数 N,p に対して,
\begin{equation}
r= N \textrm{ mod } p
\end{equation}
Np で割ったときの余りと定めます. 小学校で初めて割り算を習ったときのように, \text{"割られる数"}=\text{"割る数"}\times\text{"商"}+\text{"余り"} という関係を思い起こせば, Npr の間には,

『ある整数 k が存在して,  N= pk +r

という等式が成立していることがわかります("ある整数 k が存在して"というのは慣れてないと不気味に感じることがありますが, 今の場合商は何でも良いことから"気にしない"ことにしています). 実は, a,b,p に対して重要な公式
\begin{equation}
(a \times b) \textrm{ mod } p = ( (a \textrm{ mod } p ) \times (b \textrm{ mod } p)) \textrm{ mod } p
\end{equation}
が成立します. 証明は,  r_1=a \textrm{ mod } p,  r_2=b \textrm{ mod } p とおくと, ある整数  k_1, k_2 を用いて a = pk_1+r_1, b=pk_2+r_2 となることから,
\begin{equation}
a \times b=p(pk_1k_2+r_1+r_2)+r_1r_2
\end{equation}
一方で
\begin{equation}
(a \textrm{ mod } p ) \times (b \textrm{ mod } p)= r_1r_2
\end{equation}
となることから, 両者を p で割ったときの余りは等しいことにより従います. 以上のことから, 掛け算をするたびに毎回pで割った余りを求めて計算していくことによって, p より大きい数が表れずに済みます.

以下が a^n \textrm{ mod } p を求めるソースコードです.

def power_func(a,n,p):
    bi = str(format(n,"b"))#2進表現に
    res = 1
    for i in range(len(bi)):
        res = (res*res) %p
        if bi[i] == "1":
            res = (res*a) %p
    return res