Baby Stepsなブログ

競プロとか。間違ったこと書いてあったら@pi0nep1oneにご連絡ください。

Trie木の実装をライブラリ化した

前提

  • Trie木は、複数の単語(文字列)を登録でき、ある文字列(またはそのprefix)が登録済みであるかを高速に検索できるデータ構造
  • 基本的な機能はシンプルに、insert / searchの2つだけ
  • 実態は有向木
  • 基本的な動作は、単語のprefixが同じならNodeを共有し、異なる文字が現れた際に新規にNodeを追加していく
    • prefixを共有しながら枝分かれしていくイメージ
  • 各Nodeは3つの基本情報を持つ
    1. next: 次のNodeへのポインタ
      • 扱いたい文字種の数char_size(例: a-zなら26)だけ、ポインタを持つ必要がある
    2. commom: そのノードがいくつの単語で共有されているか
      • この情報を持たせることで、そのNode以降にいくつの単語が登録済みかがわかる
      • 故に、Node(0) を確認すればTrie木に登録済みの単語の総数がわかる
    3. accept: そのノードがある単語の末尾文字であるか
      • 単語検索の際にはacceptを確認し、prefix検索の際には確認しない
  • 今Trie木が管理している文字数を知りたい場合は、単にNodeの数をカウントすればよい(Node(0)は文字ではないので、正確にはsize - 1になる)

Trie木のわかりやすい解説

algo-logic.info

検証用問題

leetcode.com

※ 以下の実装を使うには、メソッドのシグネチャとtemplate部分を変更が必要

実装

/**
 * Trie-Tree
 * 
 * @param char_size Trie木で扱う文字種数
 * @param base Trieで扱う文字種のうちの先頭
 */
template<int char_size, int base>
class Trie{
  private:
  struct Node{
    int common;         // この頂点を共有する単語数
    vector<int> next;   // 次の頂点id
    vector<int> accept; // この頂点を終端とする単語idを保持する
    Node(): common(0), next(vector<int>(char_size, -1)) {};
  };
  vector<Node> nodes;
  
  public:
  Trie(): nodes(vector<Node>(1)) {};
  
  void insert(const string& s, int string_id=0){
    nodes[0].common++;
    int last=0;
    for(int i=0; i<s.size(); i++){
      int char_id=(int)(s[i]-base);
      if(nodes[last].next[char_id]==-1){
        nodes.push_back(Node());
        nodes[last].next[char_id]=nodes.size()-1;
        last=nodes.size()-1;
      }else{
        last=nodes[last].next[char_id];
      }
      nodes[last].common++;
    }
    nodes[last].accept.push_back(string_id);
  }
  
  bool search(const string& s){
    int last=0;
    for(int i=0; i<s.size(); i++){
      int char_id=(int)(s[i]-base);
      if(nodes[last].next[char_id]==-1) return false;
      else last=nodes[last].next[char_id];
    }
    return nodes[last].accept.size()>0 ? true : false;
  }
  
  bool search_prefix(const string& s){
    int last=0;
    for(int i=0; i<s.size(); i++){
      int char_id=(int)(s[i]-base);
      if(nodes[last].next[char_id]==-1) return false;
      else last=nodes[last].next[char_id];
    }
    return true;
  }
  
  int count_words(){
    return nodes[0].common;
  }
  
  int size(){
    return nodes.size();
  }
};