メモの日々


2020年02月27日(木) [長年日記]

[dev][math] 累乗を高速に計算する

累乗の計算はバイナリ法(繰り返し二乗法)と呼ばれるアルゴリズムで高速に計算できる。

再帰で考えるとわかりやすいのでそれをメモ。

\[ \ a^x = \begin{cases} 1 & (x = 0) \\ (a^2)^{\frac{x}{2}} & (\text{$x$が偶数}) \\ a \cdot a^{x-1} & (\text{$x$が奇数}) \end{cases} \]

が成り立つので、これをそのままコードに書けばよい。

素朴な実装と繰り返し二乗法の実行時間を比較するコードを示す。計算結果は巨大な数になるため、1000000007 で割った余りを計算するようにしている。

#include <chrono>
#include <iomanip>
#include <iostream>
#include <string>

// 素朴な実装
long mod_pow0(long a, long x, long m)
{
  long result = 1;
  for (long i = 0; i < x; ++i) {
    result = result * a % m;
  }
  return result;
}

// 繰り返し二乗法
long mod_pow(long a, long x, long m)
{
  if (x <= 0) return 1;
  return x & 1
      ? a * mod_pow(a, x - 1, m) % m
      : mod_pow(a * a % m, x >> 1, m);
}

// 時間計測をする
template<typename F>
void measure(const std::string& name, F f)
{
  namespace ch = std::chrono;
  const auto s = ch::steady_clock::now();
  const auto result = f();
  const auto e = ch::steady_clock::now();
  const auto d = ch::duration_cast<ch::duration<double>>(e - s).count();
  std::cout << std::fixed
      << name << " = " << result
      << " (" << d << " sec)"
      << std::endl;
}

int main()
{
  constexpr long m = 1000000000 + 7;
  long a, x;
  std::cin >> a >> x;

  measure("mod_pow0", [a, x, m]() { return mod_pow0(a, x, m); });
  measure("mod_pow ", [a, x, m]() { return mod_pow(a, x, m); });
}

手元の環境だと、素朴な実装でもxが10**8までなら1秒程度で計算できるが、10**9になると10秒以上かかる。繰り返し二乗法は速い。

$ ./a.out
2 10000000
mod_pow0 = 255718402 (0.128632 sec)
mod_pow  = 255718402 (0.000001 sec)

$ ./a.out
2 100000000
mod_pow0 = 494499948 (1.274563 sec)
mod_pow  = 494499948 (0.000001 sec)

$ ./a.out
2 1000000000
mod_pow0 = 140625001 (13.117252 sec)
mod_pow  = 140625001 (0.000002 sec)