Define-By-Run風味の全探索ライブラリを作る
これ何?
Define-By-Run とは
ChainerやOptunaで使われているAPIの作られ方……という認識
例えばChainerであれば「入力画像/ベクトル にいくつか層を適用して出力を計算する」という(ようにみえる)コードを書くことで、内部的にネットワーク構造が作られ、その学習を行うことができる(確か)
例えばOptunaであれば「乱数(っぽいもの)を振って計算し、それを利用して評価値を出力する」という(ようにみえる)コードを書くことで、内部的にその評価値の分布をTPEやCMA-ESにかけて最適化できる(確か)
特にOptunaは例えば「最初に乱数でステップ数を決定して、そのステップ数に応じて処理のクオリティや複雑度を変える」といったことも可能
要は「手続き的に書いているだけで、そのコードの構造や性質を具体的に定義せずとも、ライブラリ側で抽出できる」という手法……と考えられる
今回はこの考えをもとに「Define-By-Run風味の全探索ライブラリ」を作ります。
ただこれをDefine-By-Runと名付けて良いものか正直怪しいので、以降はすべてDBRと略記し区別することにします……(すでに名のあるテクニックだったら教えてください)
作りたいもの
例えば以下の問題を考えましょう(つい最近の問題なのでupsolveしたい人は気をつけてください、このページでは解法には触れません):
https://atcoder.jp/contests/arc168/tasks/arc168_c
小さいケースで検算するために、操作後にありえる文字列を全探索することを考えるとDFSを書きたくなります:
vector<string> bf; void dfs(string s, int k) { bf.push_back(s); if (k == 0) return; int n = (int)s.size(); for (int a = 0; a < n; ++a) { for (int b = 0; b < a; ++b) { swap(s[a], s[b]); dfs(s, k-1); swap(s[a], s[b]); } } } void main() { int n,k; string s; cin>>n>>k>>s; dfs(s,k); sort(ALL(bf)); bf.erase(ALL(bf), bf.end()); }
これ結構冗長だな……って思いますよね、それでこうです:
void main() { int n,k; string s; cin>>n>>k>>s; auto f = [&](DBRSource& source) { string t = s; int swap_num = source.choice(0,k); for (int times = 0; times < swap_num; ++times) { int a = source.choice(1,n-1); int b = source.choice(0,a-1); swap(t[a], t[b]); } return t; }; vector<string> bf = DBR<string>(f); sort(ALL(bf)); bf.erase(ALL(bf), bf.end()); }
割と素直に書ける気がしてきませんか……?
作るもの
変数を生成できるオブジェクト を受け取り、値を返す関数 f(source) を定義します。
関数 f では source を使って「呼び出しごとに変わる変数」を生成します。
この変数を使って値を作り、返します。これが全探索の内の1回の探索に相当します。
この関数 f を受け取って、DBRSource を渡しながら繰り返し呼び出して値を記録する関数 DBR を設計します。
作る
source から出力できるのはひとまず (l,r) を受け取って l以上r以下の数 を返すものだけにします。
最終的な実装は実は結構簡単で、こんな感じになります:
class DBRSource { public: int choice(int l, int r) { if (query_times >= (int)history.size()) { // new query Query q; q.type = QueryType_Choice; q.query_value[0] = l; q.query_value[1] = r; history.push_back(q); } auto& q = history[query_times++]; assert(q.type == QueryType_Choice && q.query_value[0] == l && q.query_value[1] == r); return q.query_value[0] + q.internal_value; } //private: enum QueryType { QueryType_Choice, QueryTypeNum }; struct Query { QueryType type = QueryType_Choice; int query_value[4]; int internal_value = 0; }; std::vector<Query> history; int query_times = 0; }; template<class X, typename F> std::vector<X> DBR(F f) { std::vector<X> ret; DBRSource source; source.history.clear(); source.query_times = 0; do { source.query_times = 0; ret.push_back(f(source)); while (source.history.size() > 0u) { auto& q = source.history.back(); q.internal_value++; bool is_end = false; switch (q.type) { case DBRSource::QueryType_Choice: if (q.internal_value > q.query_value[1] - q.query_value[0]) is_end = true; break; default: // expected error is_end = true; } if (!is_end) break; source.history.pop_back(); } } while (source.history.size() > 0); return ret; }
もうちょっとインターフェースとか整理したほうが良いんですが、それは追々……
基本的には DBRSource に問い合わせされた履歴を持って、前の問い合わせより1多い値を返す、というだけでOKです。
同じ問い合わせかどうかという判定は「何回目に呼ばれたか」で管理すれば良いです。
が、ARC168 Cの例のように、スワップ回数自体が可変の場合や、nC2選択のために値の範囲自体が可変の場合があるため、実際には返した値に応じて分岐する木構造を考える必要があります。
(木構造という観点では、これはDecision TreeをDefine-By-Runで構築している、とも考えられそうです)
ただ実際には木構造のあるパスにしか興味が無いので、結局「値を払い出し終わったクエリをpop_backする」という形でクエリの履歴を管理すれば十分です。
実践
ARC168 Cで検算に使いました (以下提出の95行目のnaive関数):
https://atcoder.jp/contests/arc168/submissions/47835839
展望
今回は一様乱数のようなインターフェースしか無いですが、以下のような拡張を作ると便利かもしれません:
- permutation
- combination
また、DBRSource を受け取って値を返す(DBRSource以外に副作用を産まない)関数を作ることで、その場で拡張することも容易そうです(Reactのカスタムフックみたいな考え方)。
まとめ
全探索してる時点で早解き的には負けなので、使わないに越したことは無いけど、たまにあれば便利そうですね
元ツイート
(現ポスト)
こういうのあったらたまに便利そうだなと思ったが、意外と簡単に作れた(Define-by-run) pic.twitter.com/PrZ9wGmcSB
— リッキー (@rickytheta) 2023年11月14日
平日の深夜2時に何してるんでしょうね