説明

計算量

  • クエリ $O(\log |S|)$

実装例

exist は子供以下に追加された文字列の個数, accept はそのノードにマッチする全ての追加された文字列の番号が格納される。

coutn_less($x)$ は $x$ 未満の値の個数, get_kth($k$) は $k$ 番目(1-indexed)に小さい値を返す。mex_query() は追加される値が相異なる必要がある(重複があってもマージを真面目にやればできるはずだけど実装してない)

template< typename T, int MAX_LOG >
struct BinaryTrie {
  BinaryTrie *nxt[2];
  T lazy;
  int exist;
  bool fill;
  vector< int > accept;

  BinaryTrie() : exist(0), lazy(0), nxt{nullptr, nullptr} {}

  void add(const T &bit, int bit_index, int id) {
    propagate(bit_index);
    if(bit_index == -1) {
      ++exist;
      accept.push_back(id);
    } else {
      auto &to = nxt[(bit >> bit_index) & 1];
      if(!to) to = new BinaryTrie();
      to->add(bit, bit_index - 1, id);
      ++exist;
    }
  }

  void add(const T &bit, int id) {
    add(bit, MAX_LOG, id);
  }

  void add(const T &bit) {
    add(bit, exist);
  }

  void del(const T &bit, int bit_index) {
    propagate(bit_index);
    if(bit_index == -1) {
      exist--;
    } else {
      nxt[(bit >> bit_index) & 1]->del(bit, bit_index - 1);
      exist--;
    }
  }

  void del(const T &bit) {
    del(bit, MAX_LOG);
  }


  pair< T, BinaryTrie * > max_element(int bit_index) {
    propagate(bit_index);
    if(bit_index == -1) return {0, this};
    if(nxt[1] && nxt[1]->size()) {
      auto ret = nxt[1]->max_element(bit_index - 1);
      ret.first |= T(1) << bit_index;
      return ret;
    } else {
      return nxt[0]->max_element(bit_index - 1);
    }
  }

  pair< T, BinaryTrie * > min_element(int bit_index) {
    propagate(bit_index);
    if(bit_index == -1) return {0, this};
    if(nxt[0] && nxt[0]->size()) {
      return nxt[0]->min_element(bit_index - 1);
    } else {
      auto ret = nxt[1]->min_element(bit_index - 1);
      ret.first |= T(1) << bit_index;
      return ret;
    }
  }

  T mex_query(int bit_index) { // distinct-values
    propagate(bit_index);
    if(bit_index == -1 || !nxt[0]) return 0;
    if(nxt[0]->size() == (T(1) << bit_index)) {
      T ret = T(1) << bit_index;
      if(nxt[1]) ret |= nxt[1]->mex_query(bit_index - 1);
      return ret;
    } else {
      return nxt[0]->mex_query(bit_index - 1);
    }
  }

  int64_t count_less(const T &bit, int bit_index) {
    propagate(bit_index);
    if(bit_index == -1) return 0;
    int64_t ret = 0;
    if((bit >> bit_index) & 1) {
      if(nxt[0]) ret += nxt[0]->size();
      if(nxt[1]) ret += nxt[1]->count_less(bit, bit_index - 1);
    } else {
      if(nxt[0]) ret += nxt[0]->count_less(bit, bit_index - 1);
    }
    return ret;
  }

  pair< T, BinaryTrie * > get_kth(int64_t k, int bit_index) { // 1-indexed
    propagate(bit_index);
    if(bit_index == -1) return {0, this};
    if((nxt[0] ? nxt[0]->size() : 0) < k) {
      auto ret = nxt[1]->get_kth(k - (nxt[0] ? nxt[0]->size() : 0), bit_index - 1);
      ret.first |= T(1) << bit_index;
      return ret;
    } else {
      return nxt[0]->get_kth(k, bit_index - 1);
    }
  }

  pair< T, BinaryTrie * > max_element() {
    assert(exist);
    return max_element(MAX_LOG);
  }

  pair< T, BinaryTrie * > min_element() {
    assert(exist);
    return min_element(MAX_LOG);
  }

  T mex_query() {
    return mex_query(MAX_LOG);
  }

  int size() const {
    return exist;
  }

  void xorpush(const T &bit) {
    lazy ^= bit;
  }

  int64_t count_less(const T &bit) {
    return count_less(bit, MAX_LOG);
  }

  pair< T, BinaryTrie * > get_kth(int64_t k) {
    assert(0 < k && k <= size());
    return get_kth(k, MAX_LOG);
  }

  void propagate(int bit_index) {
    if((lazy >> bit_index) & 1) swap(nxt[0], nxt[1]);
    if(nxt[0]) nxt[0]->lazy ^= lazy;
    if(nxt[1]) nxt[1]->lazy ^= lazy;
    lazy = 0;
  }
};

応用1: 永続Trie

かきなおす

Trie は木なので比較的容易に永続できる。以下では、2進Trieをポインタベースで実装し、永続している。

template< typename T >
struct BinaryTrieNode {
  using Node = BinaryTrieNode< T >;
 
  BinaryTrieNode< T > *nxt[2];
  int max_index;
 
  BinaryTrieNode() : max_index(-1) {
    nxt[0] = nxt[1] = nullptr;
  }
 
  void update_direct(int id) {
    max_index = max(max_index, id);
  }
 
  void update_child(Node *child, int id) {
    max_index = max(max_index, id);
  }
 
  Node *add(const T &bit, int bit_index, int id, bool need = true) {
    Node *node = need ? new Node(*this) : this;
    if(bit_index == -1) {
      node->update_direct(id);
    } else {
      const int c = (bit >> bit_index) & 1;
      if(node->nxt[c] == nullptr) node->nxt[c] = new Node(), need = false;
      node->nxt[c] = node->nxt[c]->add(bit, bit_index - 1, id, need);
      node->update_child(node->nxt[c], id);
    }
    return node;
  }
 
  inline T min_query(T bit, int bit_index, int bit2, int l) {
    if(bit_index == -1) return bit;
    int c = (bit2 >> bit_index) & 1;
    if(nxt[c] != nullptr && l <= nxt[c]->max_index) {
      return nxt[c]->min_query(bit, bit_index - 1, bit2, l);
    } else {
      return nxt[1 ^ c]->min_query(bit | (1LL << bit_index), bit_index - 1, bit2, l);
    }
  }
};
 
template< typename T, int MAX_LOG >
struct PersistentBinaryTrie {
  using Node = BinaryTrieNode< T >;
  Node *root;
 
  PersistentBinaryTrie(Node *root) : root(root) {}
 
  PersistentBinaryTrie() : root(new Node()) {}
 
  PersistentBinaryTrie add(const T &bit, int id) {
    return PersistentBinaryTrie(root->add(bit, MAX_LOG, id));
  }
 
  T min_query(int bit, int l) {
    return root->min_query(0, MAX_LOG, bit, l);
  }
};