Baby Stepsなブログ

競プロとか。間違ったこと書いてあったら@pi0nep1oneにご連絡ください。

ダブリングによる最近共通祖先(Lowest Common Ancestor)のライブラリを作った

前提

ダブリング手順

  • dp[i][j]:=頂点iの2j個上の頂点
    • iは頂点数nに対して、log_n
    • 初期化: まずdpを-1で初期化したのち、dp[0][j]、つまり1個上のノードを登録する
    • 以降、dp[i+1][j]=dp[i][dp[i][j]] で更新

ダブリングのわかりやすい解説

satanic0258.hatenablog.com

algo-logic.info

LCA手順

  • まず与えられる(u, v)の深さを揃える(深い方をrootに向けて移動させる。ここでは uが深いという前提)
    • 揃える処理をナイーブに実装するなら O(n): nは頂点数
    • 深さの差をdとし2進表記した時、立っているbitの位置をkとすると、u=dp[k][u] という更新を繰り返すことで深さを揃えられる
    • この操作回数は高々log_n
  • 深さを揃えた結果、u, v が一致するならLCAはu(またはv)
  • そうでなければ、u, vをrootに向けてLCAの一歩手前まで移動させていく
    • この操作もナイーブに実装すると O(n)
    • ダブリングで求めたテーブルを使い、iを大きい方から始めて、dp[i][u]!=dp[i][v] となる場合、u,vを2i個分移動させることを繰り返す。最終的なu,vの1個上のノードがLCAになる
    • 上の操作回数は高々log_n

なぜdだけ移動させるのに、dのbitに着目すると良いの?

  • 繰り返し二乗法と同じ
  • 例で考えるとわかりやすい
    • d=5だけ移動させたい
    • 5=(101) であり、立っているbitに着目すると、20+22
    • ダブリングで用意したテーブルを使うことで、↑の操作を再現でき、その操作回数は ⌊log_d⌋ + 1 回になっている

LCAのわかりやすい解説

ダブリングによる木の最近共通祖先(LCA:Lowest Common Ancestor)を求めるアルゴリズム | アルゴリズムロジックalgo-logic.info

検証用問題

onlinejudge.u-aizu.ac.jp

実装

/**
 * Lowest Common Ancestor
 */
class LCA{
  private:
  int root;
  int k; // n<=2^kとなる最小のk
  vector<vector<int>> dp; // dp[i][j]:=要素jの2^i上の要素
  vector<int> depth;  // depth[i]:=rootに対する頂点iの深さ
  
  public:
  LCA(const vector<vector<int>>& _G, const int _root=0){
    int n=_G.size();
    root=_root;
    k=1;
    int nibeki=2;
    while(nibeki<n){
      nibeki<<=1;
      k++;
    }
    // 頂点iの親ノードを初期化
    dp = vector<vector<int>>(k+1, vector<int>(n, -1));
    depth.resize(n);
    function<void(int, int)> _dfs=[&](int v, int p){
      dp[0][v]=p;
      for(auto nv: _G[v]){
        if(nv==p) continue;
        depth[nv]=depth[v]+1;
        _dfs(nv, v); 
      }
    };
    _dfs(root, -1);
    // ダブリング
    for(int i=0; i<k; i++){
      for(int j=0; j<n; j++){
        if(dp[i][j]==-1) continue;
        dp[i+1][j]=dp[i][dp[i][j]];
      }
    }
  }
  
  /// get LCA
  int get(int u, int v){
    if(depth[u]<depth[v]) swap(u,v); // u側を深くする
    if(depth[u]!=depth[v]){
      long long d=depth[u]-depth[v];
      for(int i=0; i<k; i++) if((d>>i)&1) u=dp[i][u];
    }
    if(u==v) return u;
    
    for(int i=k; i>=0; i--){
      if(dp[i][u]!=dp[i][v]){
        u=dp[i][u], v=dp[i][v];
      }
    }
    return dp[0][u];
  }
  
  int get_distance(const int u, const int v){
    int lca=get(u,v);
    return depth[u]+depth[v]-2*depth[lca];
  }
};