nashidos’s diary

アルゴリズムとか機械学習とか色々

Pythonでbit全探索を実装してみる-ABC079

bit全探索

全探索にもいろいろ種類がありますが今回はbit全探索をPythonで実装していきたいと思います。

bit全探索とはその名の通りbit演算を利用して行う全探索で、Yes or Noのような2択を網羅的に探索する時に使えます。説明を聞いてもわかりにくいかもしれませんが、実際に問題を解けば理解しやすいと思います。

問題

今回は以下の問題を扱います。
atcoder.jp


問題文

駅の待合室に座っているjoisinoお姉ちゃんは、切符を眺めています。
切符には 4つの 0 以上 9 以下の整数 A,B,C,Dが整理番号としてこの順に書かれています。
A op1 B op2 C op3 D = 7 となるように、op1,op2,op3に + か - を入れて式を作って下さい。
なお、答えが存在しない入力は与えられず、また答えが複数存在する場合はどれを出力してもよいものとします。

制約

0≦A,B,C,D≦9
入力は整数
答えが存在しない入力は与えられない

実装

冒頭で説明したようにbit全探索はYes or Noのような2択を網羅的に探索することができます。今回の問題では+ or -の2択を網羅的に探索します。

そこで今回は以下のように実装しました。

a,b,c,d = input()
for i in range(2**3):
    ls = ["+","+","+"]
    for j in range(len(ls)):
        if (i >> j) & 1:
            ls[j] = "-"
    if eval(a+ls[0]+b+ls[1]+c+ls[2]+d) == 7:
        print(a+ls[0]+b+ls[1]+c+ls[2]+d+"=7")
        break

今回は+,-の組み合わせをop1,op2,op3の3つで探索するので2^3通り探索します。

このコードの一番大事なところは (i >> j) & 1 の部分です。

これはビットi(2進数)をj回右にシフトして 1 と論理積を取る(最下位の桁が 1 かどうかを調べる)といったことをしています。こうすることで部分集合の全パターンを探索することができます。

実際に先ほどのコードを少し変えて正しくbit全探索できているか確認してみましょう。

a,b,c,d = input()
for i in range(2**3):
    ls = ["+","+","+"]
    for j in range(len(ls)):
        if (i >> j) & 1:
            ls[j] = "-"
    print(ls)

出力

['+', '+', '+']
['-', '+', '+']
['+', '-', '+']
['-', '-', '+']
['+', '+', '-']
['-', '+', '-']
['+', '-', '-']
['-', '-', '-']

正しく全パターンを列挙できていました。

bitをシフトするところが少し難しいですが、とりあえずbit全探索はYes or Noのような2択を網羅的に探索する必要がある場合に使用するということは覚えておきましょう。

ちなみに今回の問題は8通りだけ探索すればいいので全部場合分けすれば一応通ります。
例えば以下のようなひどいコードでも一応通ります(真似しないでね!)

import sys
s = (input())
a = int(s[0])
b = int(s[1])
c = int(s[2])
d = int(s[3])
ans = a+b+c+d
if ans == 7:
    print(s[0]+"+"+s[1]+"+"+s[2]+"+"+s[3]+"=7")
    sys.exit()
ans = a+b+c-d
if ans == 7:
    print(s[0]+"+"+s[1]+"+"+s[2]+"-"+s[3]+"=7")
    sys.exit()
ans = a+b-c+d
if ans == 7:
    print(s[0]+"+"+s[1]+"-"+s[2]+"+"+s[3]+"=7")
    sys.exit()
ans = a-b+c+d
if ans == 7:
    print(s[0]+"-"+s[1]+"+"+s[2]+"+"+s[3]+"=7")
    sys.exit()
ans = a-b-c+d
if ans == 7:
    print(s[0]+"-"+s[1]+"-"+s[2]+"+"+s[3]+"=7")
    sys.exit()
ans = a+b-c-d
if ans == 7:
    print(s[0]+"+"+s[1]+"-"+s[2]+"-"+s[3]+"=7")
    sys.exit()
ans = a-b+c-d
if ans == 7:
    print(s[0]+"-"+s[1]+"+"+s[2]+"-"+s[3]+"=7")
    sys.exit()
ans = a-b-c-d
if ans == 7:
    print(s[0]+"-"+s[1]+"-"+s[2]+"-"+s[3]+"=7")
    sys.exit()

いやぁ、ひどいコードですね。ちゃんとbit全探索を使って解きましょうね。

おわり