Improving the DQN CartPole Example
Published: 2021-05-06 22:02:08


Original post:
By adding shaped rewards, the agent reliably reaches the 999-step cap after roughly 100 episodes.
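The key change relative to the standard CartPole DQN is reward shaping inside the episode loop: an extra penalty when the pole falls before the step limit, plus small bonuses for keeping the cart near the center of the track and the pole near vertical. The snippet below isolates that step (the same logic appears in the full script further down); the CartPole-v1 observation layout [cart position, cart velocity, pole angle, pole angular velocity] is what the index checks rely on.

next_state, reward, done, _ = env.step(action)
if done and time < 999:
    reward -= 10          # episode ended early: pole fell or cart left the track
if abs(next_state[0]) < 0.2:
    reward += 1           # bonus: cart stays close to the track center
if abs(next_state[2]) < 0.02:
    reward += 1           # bonus: pole stays close to vertical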

# -*- coding: utf-8 -*-
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras import backend as K
import tensorflow as tf

EPISODES = 5000


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.99    # discount rate
        self.epsilon = 1.0   # exploration rate
        self.epsilon_min = 0.001
        self.epsilon_decay = 0.001   # linear decay per replay call
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    """Huber loss for Q Learning
    References: https://en.wikipedia.org/wiki/Huber_loss
                https://www.tensorflow.org/api_docs/python/tf/losses/huber_loss
    """
    def _huber_loss(self, y_true, y_pred, clip_delta=1.0):
        error = y_true - y_pred
        cond = K.abs(error) <= clip_delta
        squared_loss = 0.5 * K.square(error)
        # linear portion of the Huber loss, used when |error| > clip_delta
        linear_loss = 0.5 * K.square(clip_delta) + clip_delta * (K.abs(error) - clip_delta)
        return K.mean(tf.where(cond, squared_loss, linear_loss))

    def _build_model(self):
        # Neural Net for Deep-Q learning Model
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss=self._huber_loss,
                      optimizer=Adam(lr=self.learning_rate))
        return model

    def update_target_model(self):
        # copy weights from model to target_model
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        # print(act_values, state)
        return np.argmax(act_values[0])  # returns action

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = self.model.predict(state)
            if done:
                target[0][action] = reward
            else:
                # a = self.model.predict(next_state)[0]
                t = self.target_model.predict(next_state)[0]
                target[0][action] = reward + self.gamma * np.amax(t)
                # target[0][action] = reward + self.gamma * t[np.argmax(a)]
            self.model.fit(state, target, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon -= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)


if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    # agent.load("./save/cartpole-ddqn.h5")
    done = False
    batch_size = 32

    for e in range(EPISODES):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        for time in range(1000):
            # env.render()
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # reward shaping: penalize early failure, reward staying centered and upright
            if not done:
                reward = reward
            elif time == 999:
                reward = reward
            else:
                reward = reward - 10
            if abs(next_state[0]) < 0.2:
                reward += 1
            if abs(next_state[2]) < 0.02:
                reward += 1
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                agent.update_target_model()
                print("episode: {}/{}, score: {}, e: {:.2}"
                      .format(e, EPISODES, time, agent.epsilon))
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        # if e % 10 == 0:
        #     agent.save("./save/cartpole-ddqn.h5")
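Once training has produced a checkpoint (the agent.save("./save/cartpole-ddqn.h5") call is commented out above and would need to be re-enabled), the learned policy can be replayed greedily. The following is a minimal evaluation sketch under that assumption; the eval_agent name and the 1000-step budget are illustrative choices, not part of the original script.

# Evaluation sketch: load saved weights and act greedily (no exploration)
eval_agent = DQNAgent(state_size, action_size)
eval_agent.load("./save/cartpole-ddqn.h5")
eval_agent.epsilon = 0.0              # always pick the argmax action
state = np.reshape(env.reset(), [1, state_size])
for t in range(1000):
    env.render()
    action = eval_agent.act(state)
    next_state, reward, done, _ = env.step(action)
    state = np.reshape(next_state, [1, state_size])
    if done:
        print("survived {} steps".format(t))
        break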