저번 포스트에서 말했던 Tic-Tac-Toe 게임을 하는 바닐라 MCTS를 Python을 통해 구현해 보았다.
Tic-Tac-Toe란?
Tic-Tac-Toe란, 오목과 비슷한 서양 게임이다. 3x3의 판에서 3개를 가로/세로/대각선 방향으로 연속 3개를 먼저 놓는 사람이 이긴다.
굉장히 간단한 게임으로, 양 쪽 플레이어가 합리적인 판단만을 한다면 항상 무승부만 나오는 게임이고, "올바른" 수가 모든 상황에서 항상 존재하기 때문에 MCTS 알고리즘의 검증용으로는 좋을 것 같다.
후술할 설명을 더 간결하게 하기 위해, Tic-Tac-Toe 보드의 각 자리를 다음과 같이 숫자를 통해 설명하겠다.
[ 1 2 3 ]
[ 4 5 6 ]
[ 7 8 9 ]
먼저, 선공을 잡은 플레이어는 항상 5번 자리에 놓는 것이 후공의 이기는 수를 제한하면서 본인의 이기는 수를 극대화하기 때문에 유리하다.
[ _ _ _ ]
[ _ O _ ]
[ _ _ _ ]
후공 플레이어는 이때 2, 4, 6, 8번 자리에 놓으면 선공 플레이어에게 100% 이기는 수를 제공하기 때문에, 항상 1, 3, 7, 9번 자리에 두어야 한다.
[ X _ _ ]
[ _ O _ ]
[ _ _ _ ]
이 보드로부터는 선공 플레이어는 최대한 공격의 수를 만들고, 후공 플레이어는 그 수를 최대한 막으면 무승부로 게임이 끝난다.
MCTS 코드
그러면 구현한 MCTS 코드를 공유하겠다.
import time
import numpy as np
import random
import copy
class MCTS_Node:
def __init__(self, playerid, state, parent=None):
self._state = state
self._parent = parent
self._children = []
self._untried_actions = []
self._number_of_visits = 0
self._playerid = playerid
self._is_fully_expanded = False
self._value = {-1: 0, 0: 0, 1: 0}
self._board = self._state.get_board()
def get_state(self):
return self._state
def get_parent(self):
return self._parent
def get_value(self):
return self._value[self._playerid] - self._value[-self._playerid]
def get_untried_action(self): # essentially selection stage
if not self._untried_actions and not self._children:
self._untried_actions = self._state.get_legal_actions()
if len(self._untried_actions) == 1:
self._is_fully_expanded = True
return self._untried_actions.pop()
def expansion(self):
next_action = self.get_untried_action()
next_state = self._state.get_next_state(self._playerid, next_action) # this is where current player makes his next move
next_child_node = MCTS_Node(playerid=self._playerid * -1, state=next_state, parent=self)
self._children.append(next_child_node)
return next_child_node
def rollout(self):
return self._state.play_out(self._playerid)
def backpropagate(self, result):
self._number_of_visits += 1
self._value[result] += 1
if self._parent is not None:
self._parent.backpropagate(result)
def is_terminal_node(self):
return True if self._state.is_state_terminated() else False
def is_fully_expanded(self):
return self._is_fully_expanded
def num_of_visits(self):
return self._number_of_visits
def best_child(self, param_exploration):
choices_weights = [
(c.get_value() / c.num_of_visits()) + param_exploration * np.sqrt((2 * np.log(self._number_of_visits) / c.num_of_visits()))
for c in self._children
]
return self._children[np.argmax(choices_weights)]
def best_child_weights(self, param_exploration):
choices_weights = [
(c.get_value() / c.num_of_visits()) + param_exploration * np.sqrt((2 * np.log(self._number_of_visits) / c.num_of_visits()))
for c in self._children
]
return choices_weights
class MCTS_TicTacToe:
def __init__(self, board_length: int, win_length: int):
self._board_length = board_length
self._win_condition = win_length # for simplicity
# self._win_condition = board_length
self._board = self.create_gameboard(self._board_length)
self._last_moved = -1
self._winner = None
def create_gameboard(self, len_board: int):
# O is 1, X is -1, empty is 0
return np.zeros((len_board, len_board), dtype=np.int8)
def get_legal_actions(self, board=None):
"""Returns numpy array of possible moves"""
if board is None:
board = self._board
legal_actions = [array for array in np.argwhere(board == 0)]
random.shuffle(legal_actions)
return legal_actions
def is_state_terminated(self, board=None):
"""Check if current game state warrants termination"""
if board is None:
board = self._board
if (board == 0).sum() == 0:
self._winner = 0
return True
for player_id in [-1, 1]:
for k in range(self._board_length - self._win_condition):
for j in range(self._board_length - self._win_condition):
sub_board = self._board[k:k+self._win_condition, j:j+self._win_condition]
# Check horizontal, vertical
for i in range(self._win_condition):
if np.all(sub_board[:, i] == player_id):
self._winner = player_id
return True
if np.all(sub_board[i, :] == player_id):
self._winner = player_id
return True
# Check diagonals
if np.all(np.diagonal(sub_board) == player_id) or np.all(np.fliplr(sub_board).diagonal() == player_id):
self._winner = player_id
return True
return False
def play_out(self, playerid):
"""Returns final state after playing out randomly"""
board = self._board.copy()
last_moved = playerid
while not self.is_state_terminated(board):
moves = self.get_legal_actions(board)
if not moves:
winner = 0 # Tie
return winner
next_move = random.choice(moves)
board[next_move[0], next_move[1]] = playerid
last_moved = playerid
playerid *= -1
winner = last_moved
return winner
def check_winner(self, playerid):
if self._winner == playerid:
return 1
elif self._winner == playerid * -1:
return -1
else:
return 0
def get_winner(self):
return self._winner
def get_board(self):
return self._board
def print_board(self):
blen = self._board_length
line = "-" * (3 * blen + 2)
symbols = {1: "O ", 0: " ", -1: "X "}
for i in range(blen):
curline = "[ " + " ".join([symbols[self._board[i, j]] for j in range(blen)]) + "]"
print(curline)
print(line)
def update_board(self, playerid: int, move: tuple):
self._board[move[0], move[1]] = playerid
def get_next_state(self, pid, next_action):
new_state = copy.deepcopy(self)
new_state.update_board(pid, next_action)
return new_state
class MCTS:
def __init__(self, root_node):
self.root = root_node
def find_best_action(self, num_seconds=None, num_tries=None):
tries = 0
end_time = time.time() + num_seconds
if num_tries is None:
assert(num_seconds is not None)
while True:
next_node = self.policy()
reward = next_node.rollout()
next_node.backpropagate(reward)
tries += 1
if time.time() > end_time:
break
best_child_weights = self.root.best_child_weights(param_exploration=0.)
print(tries)
return self.root.best_child(param_exploration=0.)
else:
while tries < num_tries:
next_node = self.policy()
reward = next_node.rollout()
next_node.backpropagate(reward)
tries += 1
best_child_weights = self.root.best_child_weights(param_exploration=0.)
print(f"{time.time() - end_time: .4f} seconds")
return self.root.best_child(param_exploration=0.)
def policy(self):
current_node = self.root
while not current_node.is_terminal_node():
if not current_node.is_fully_expanded():
return current_node.expansion()
else:
current_node = current_node.best_child(param_exploration=1.4)
return current_node
if __name__ == "__main__":
pid = 1
state = MCTS_TicTacToe(board_length=4, win_length=3)
state.print_board()
while not state.is_state_terminated():
player = MCTS_Node(playerid=pid, state=state)
MCTS_set = MCTS(player)
best_child = MCTS_set.find_best_action(num_seconds=20, num_tries=30000)
state = best_child.get_state()
# state.update_board(pid, best_child)
state.print_board()
pid *= -1
dict_winner = {-1: "Player 2", 0: "A Tie", 1: "Player 1"}
print(f"Winner is: {dict_winner[state.get_winner()]}!")
저번 포스팅에서 서술한 MCTS의 모든 단계가 담겨있다.
- 먼저 state, node, 그리고 MCTS 알고리즘을 정의한다
- policy을 통해 현 노드의 child 노드 중 하나를 선택한다 (1-selection). [Line 193]
- 2번에서 선택한 노드의 child를 모두 선택해본 적 있다면, 해당 노드의 best child를 UCT 알고리즘을 통해(q-value가 높고, 탐험상수/방문 횟수로 exploration을 보정) 다시 선택한다. 단, 최적의 수를 선택할 때에는 q-value만으로 선택한다. [Line 219]
- 2번에서 선택한 노드에서 가능한 action 중 하나를 무작위로 선택한다 (2-expand) [Line 217]
- 3번에서 선택된 action에 대해 무작위 수를 두며 게임이 끝날 때까지 플레이한다(3-simulation)/ [Line 194]
- 4번에서 simulation 결과로, 2번에서 선택한 노드 및 그 노드의 부모 노드에 대해 backpropagation을 한다(4-backprop).
- 이때 backpropagation은 ML의 DNN과 같이 weight에 대한 update가 아닌, Q-value를 증감시키는 것을 말한다
- 주어진 시간을 넘기거나, 주어진 수만큼 탐색을 했을 경우, 탐색한 수 중 가장 Q-value가 높은 수를 둔다.
- 1-6번을 게임이 끝날 때까지 반복한다.
MCTS를 구현하기 위해서는 다음 클래스 및 함수가 필요하다. 위에 추가로 구현한 것도 있지만, 서술하는 함수/클래스는 MCTS 실행에 있어 100% 필수적인 것들이다.
- Node class
- state, parent, children, untried_actions, num_visits, is_fully_expanded, q_value, playerid(선택사항)
- get_q_value()
- get_untried_actions()
- get_number_of_visits()
- expand()
- rollout()
- backpropagate()
- is_terminal_node()
- is_fully_expanded()
- best_child()
- State class
- create_board()
- update_board(next_move)
- get_legal_actions()
- get_next_state(next_action)
- is_terminal_state()
- play_out() (rollout에 사용)
- check_winner()
- print_board() (없어도 되지만 state에 넣는게 사용하기 편하다)
- Game class
- policy() (다음 node 선택하는 함수)
- find_best_action() (selection부터 backpropagation까지 전체적인 MCTS 기법을 실행하는 함수)
state를 생성하기도, termination state를 정의하기도 간단하여 MCTS 알고리즘을 구현하기 매우 편했던 게임인 것 같다.
대신, 직접 돌려보면 알겠지만 board length가 조금만 늘어나도 최적의 수를 구하는데 시간이 오래 걸리기 시작한다.
이는 현 알고리즘 상으로는 탐험상수를 조정하여 개선할 수 있을 것 같지만, 이 또한 보드의 크기가 늘어날수록 한계가 있을 것으로 보인다.
이를 개선하기 위해서는 현 state에 대해, 기존에 학습한 최적의 수를 바로 도출할 수 있는 ML/DL 기법이 가장 최선일 것으로 보인다.
때문에, 다음 포스트에서는 알파고에서 사용된 ML/DL과 MCTS의 접목 방식을 보고, 해당 implementation을 위 코드에 적용하는 것을 목표로 하겠다.
이번에 구현하면서 느낀 것이지만, 확실히 이론적으로 배우는 것과 실제 구현을 통해 이론을 체득하는 것의 효용은 크게 다르다. 이 글을 읽는 여러분도 직접 간단한 게임에 대해 MCTS를 구현하고, 디버깅을 할 때 위 코드를 참고하는 것을 추천한다.
필자가 코딩하면서 가장 막혔던 구간은 selection과 backpropagation 및 q-value의 설정인데, 이는 직접 한번 생각대로 해보고, 문제가 생길 시 참고하는 것이 좋을 것 같다.