読者です 読者をやめる 読者になる 読者になる

のんびりしているエンジニアの日記

ソフトウェアなどのエンジニア的な何かを書きます。

Bandit Problemと強化学習ーこれであなたも大金持ち?ー

Sponsored Links

皆さんこんにちは
お元気ですか。私は元気です。
本日はBandit Problemと呼ばれる問題を強化学習で解いてみます。

Bandit Problemについて

Bandit Problem(和名:バンディット問題)は
当たる確率の異なるスロットマシンから最も大きい報酬を得るには
どうすればよいか?といった問題です。

以下のようなスロットがあったとします。
f:id:tereka:20160703183327p:plain

しかし、実はスロット達、あたる当たる確率が異なるスロットなのです。
そのようなスロットの中で最も報酬を高くするようスロットを選んでいくにはどうすればよいかといった問題を
解くことができます。
つまり、どうすれば大金持ちになれるかわかるということです、もちろん儲かる保証はしません。

因みに一般的な問題としては、A/Bテストに使われているとかいないとか。

解き方

いくつか解法がありますが、今回は3点紹介します。

  1. epsilon greedy algorithm
  2. Softmax
  3. UCB

今回解いた問題

今回は強引に以下の割合で当たる問題を解いてみました。
0.5, 0.3, 0.7, 0.2, 0.9, 0.2, 0.3, 0.2, 0.3, 0.2, 0.4

つまり、最も良いのは0.9でこれを引き続けることが理論上、最も良い作業となります。

epsilon greedy algorithm

epsilon greedy algorithmは一定の確率で、適当に選ぶ振る舞いをし、
基本的に最も期待が高い報酬を選択します。
ある意味直感的にわかりやすい

一定回数実施すると、最も高い報酬のところを選択するよう収束するように動作します。
因みにグラフも書いてみた。
以下のグラフは縦軸が報酬の平均、横軸がトライ回数です。

f:id:tereka:20160703181902p:plain

基本的に報酬は最大に収束するように動作しますが、最悪の行動を選択した場合において
残り続けてしまうため収束が遅くなる可能性があります。

Softmax Tempature

温度と呼ばれる概念を与えたBandit Problemの解き方。
epsilon greedy algorithmでは、ランダムな選択時に全ての行動の選択が等しくなる欠点があり、
その欠点を解消するように挙動する。

基本的にはSoftmax関数であるが、温度(T)と呼ばれる概念が投入されており、
この温度関数が高いと、ランダムに振る舞おうとする傾向が強くなるそうです。
(※{X_i} = i番目の期待値、{P_i} = i番目を選択する確率)
しかし、この温度と呼ばれる概念

{ \displaystyle
P_i = \frac{e^{X_i/T}}{\sum{e^{X_i/T}}}
}

グラフ
f:id:tereka:20160703181906p:plain

UCB

UCBと呼ばれる値を導入した
UCB関数は、選択の頻度が低い値を選ぶよう挙動し、
選択の頻度の低い値を選択し、推定する。式は以下の式であり、Cは定数
nは総プレイ回数で、{n_i}はindexのiを選択した回数

{ \displaystyle
UCB_i = X_i + C\sqrt{\frac{lnn}{n_i}}
}

実際にUCB関数を使ってBandit Problemを解いたグラフが以下

f:id:tereka:20160703181908p:plain

感想

実は何回かepsilon greedyを実行していますが、結構ブレブレな挙動を示す傾向ですね。
かなり乱数に左右されるアルゴリズムであることがわかる。。。
Bandit Problemで実際のスロットやるとどうなるのか気になるけど怖くてできない。

参考文献

https://www.jstage.jst.go.jp/article/fss/30/0/30_174/_pdf(強化学習におけるUCB行動選択手法の効果)

ソースコード

では、最後にソースコードを書いてみました。
基本的には選択手法とパラメータ以外は同じです。


import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns


class Bandit(object):
    def __init__(self, prob):
        """
        Bandit
        :param prob: probability of 1 reward
        :return: None
        """
        self.prob = prob

    def get_reward(self):
        """
        calculate reward
        :return: reward
        """
        return np.random.binomial(1, self.prob)


class BanditProblem(object):
    def __init__(self, probs):
        """
        Bandit problem class
        :param probs: list of probability of 1 reward
        :return: None
        """
        self.number = len(probs)
        self.bandits = [Bandit(prob) for prob in probs]


class BanditSolver(object):
    def __init__(self):
        pass

    def select_action(self, scores, counter):
        raise NotImplementedError()

    def solve(self, bandits, times=1000):
        """
        solve the bandit problem
        :param bandits: bandits
        :param times: times(e_greedy estimation)
        :return: None
        """
        total_rewards = 0.0
        count = 0
        average_score = []
        selected_indexs = []
        scores = [0.0 for number in xrange(bandits.number)]
        counter = [0.0 for number in xrange(bandits.number)]

        for time in xrange(times):
            selected_index = self.select_action(scores, counter)

            score = bandits.bandits[selected_index].get_reward()
            counter[selected_index] += 1
            scores[selected_index] = (scores[selected_index] * (counter[selected_index] - 1) + score) / counter[
                selected_index]
            total_rewards += score
            selected_indexs.append(selected_index)

            count += 1
            average_score.append(total_rewards / count)
        return average_score, counter, scores


class AverageEgreedySolver(BanditSolver):
    def __init__(self, e_greedy):
        """
        Average Bandit Problem Solver
        :param e_greedy: probability of random
        :return: none
        """
        self.e_greedy = e_greedy
        super(AverageEgreedySolver, self).__init__()

    def select_action(self, scores, counter):
        action = np.random.binomial(1, 1 - self.e_greedy)
        if action == 1:
            selected_index = np.argmax(scores)
        else:
            selected_index = np.random.randint(len(scores))
        return selected_index


class SoftmaxBanditSolver(BanditSolver):
    def __init__(self, temperature):
        self.temperature = temperature
        super(SoftmaxBanditSolver, self).__init__()

    def select_action(self, scores, counter):
        scores = np.exp(np.array(scores) / self.temperature)
        denomminator_score = np.sum(scores)
        select_prob = scores / denomminator_score
        cumsum_prob = select_prob.cumsum()
        rand = np.random.random()
        for index,prob in enumerate(cumsum_prob):
            if rand < prob:
                return index
        return None


class UCBSolver(BanditSolver):
    def __init__(self, C):
        """
        Parameter C
        :param C: UCB Solver
        :return:
        """
        self.C = C
        super(UCBSolver, self).__init__()

    def select_action(self, scores, counter):
        return np.argmax(scores + self.C * np.log(np.sum(counter) + 0.0000001) / np.array(counter))


def solve_bandit_problem(solver_name,parameters,parameter_name):
    """
    Bandit problem solver
    :param solver_name: solver name
    :param parameters: parameters
    :param parameter_name: parameter name
    :return: None
    """
    for parameter in parameters:
        probs = [0.5, 0.3, 0.7, 0.2, 0.9, 0.2, 0.3, 0.2, 0.3, 0.2, 0.4]
        bandit = BanditProblem(probs)
        solver = None
        if solver_name == "average_bandit":
            solver = AverageEgreedySolver(parameter)
        elif solver_name == "softmax_bandit":
            solver = SoftmaxBanditSolver(parameter)
        elif solver_name == "ucb_bandit":
            solver = UCBSolver(parameter)
        average_score, counter, scores = solver.solve(bandit, times=1000)

        plt.plot(range(len(average_score)), average_score)

    plt.legend(["{} {}".format(parameter_name, parameter) for parameter in parameters])
    plt.savefig("{}.png".format(solver_name))
    plt.clf()

if __name__ == '__main__':
    solve_bandit_problem("average_bandit",[0.01,0.05,0.1],"epsilon")
    solve_bandit_problem("softmax_bandit",[0.01,0.05,0.1,0.2,1.0,2.0],"temperature")
    solve_bandit_problem("ucb_bandit",[0.01,0.05,0.1,0.2,1.0,2.0],"C")
広告を非表示にする