강화 학습(Reinforcement Learning)

강화학습 Stochastic FrozenLake-v3 Low Pass Filter 방식 알고리듬 적용

coding art 2022. 1. 7. 18:46
728x90

파이선 코딩 초보자를 위한 텐서플로우∙OpenCV 머신 러닝 2차 개정판 발행

http://blog.daum.net/ejleep1/1175

 

파이선 코딩 초보자를 위한 텐서플로우∙OpenCV 머신 러닝 2차 개정판 (하이퍼링크) 목차 pdf 파일

본서는 이미 2021년 11월 초부터 POD코너에서 주문 구입이 가능합니다. 참고로 책 목차에 따른 내용별 학습을 위한 코드는 이미 대부분 다음(Daum)블로그에 보관되어 있으며 아래에서 클릭하면 해당

blog.daum.net

 

Deterministic에서 Stochastic 으로 조건을 바꿀 경우 다음과 같이 성공률이 아주 저조하게 얻어진다.

 

이러한 성공률을 높일 수 있도록 Low Pass Filtering 방식과 유사한 알고리듬을 개량해 보자.

action 에 Q 값에 랜덤 성분을 추가하되 에피소드 값이 커짐에 따라 랜덤성이 약해짐에 유의하자.

learning_rate 파라메터를 도입하여 다소 큰 값인 0.85를 설정한다.

Q 계산식에 Low Pass Filtering 방식을 적용하여 현재의 Q 값을 조금 반영하고 다음 단계의 Q 값을 많이 반영하되 그 합이 1.0이 된다. Low Pass Filter 라고 불리는 이 알고리듬은 Q 값이 시간에 따라 들쭉 날쭉 성이 심할 경우 약간 고주파적인 특성을 효율적으로 제거할 수 있다. 비슷한 개념이 스토캐스틱 경사항강법에서 사용하는 SGD(stochastic gradient ㅇdescent) 나 Adam 옵티마이저에서도 채요이 되고 있음에 유의하자.

이렇게 알고리듬을 변경하여 실행하면 다음과 같이 상당히 효율이 개선된 51% 대의 성공률이 얻어진다.

 

#LPF_stochastic_01.py

import gym
import numpy as np
import matplotlib.pyplot as plt
import random as pr
from gym.envs.registration import register

def rargmax(vector):
    m = np.amax(vector)
    indices = np.nonzero(vector == m)[0]
    return pr.choice(indices)        

register(
    id = 'FrozenLake-v3',
    entry_point = 'gym.envs.toy_text:FrozenLakeEnv',
    kwargs = {'map_name': '4x4',
              'is_slippery':True})

env = gym.make('FrozenLake-v1')
Q = np.zeros([env.observation_space.n,env.action_space.n])

learning_rate = 0.85
num_episodes = 2000
gamma = 0.99

rList = []
for i in range(num_episodes):
    state = env.reset()
    rAll = 0
    done = False
    
    while not done:
        action =  np.argmax(Q[state, :] + np.random.randn(1, env.action_space.n)/ (i+1))
        new_state, reward, done, _ = env.step(action)
        Q[state, action] = (1 - learning_rate) * Q[state, action] + learning_rate * (reward + gamma * np.max(Q[new_state, :]))
        rAll += reward
        state = new_state
    if i < 200:
        print(i, ":  ", rAll)
    rList.append(rAll)

#print("Sucess Rate: " + str(sum(rList)/num_episodes))
print("Sucess Rate: ", (sum(rList)/num_episodes))
print("Final Q-Table Values")
print("LEFT DOWN RIGHT UP")
print(Q)
plt.bar(range(len(rList)), rList, color="blue")
plt.show()