为了账号安全,请及时绑定邮箱和手机立即绑定

DQN基础:深度Q网络(DQN)完全教程【附代码】:CSDN博客

标签:
杂七杂八

概述

深度Q网络(DQN)是强化学习领域的一项关键算法,它结合了深度神经网络与传统的强化学习框架,旨在通过与环境的互动来自动学习决策策略。本文将以PyTorch为基础,指导读者在OpenAI Gym的CartPole任务集上训练一个DQN智能体。目标是让智能体通过左右移动小车以保持杆直立。通过设置环境、准备经验回放内存,并定义DQN模型,实现智能体在复杂环境中的自动化决策。最终,通过训练循环、智能体构建及训练过程展示,实现从环境初始化到智能体策略优化的完整流程,有效提升决策学习效率与性能。

强化学习(DQN)教程

任务描述

目标是让智能体控制一个具有四个状态(位置、速度等)的杆,通过左右移动小车保持杆直立。在每个时间步,小车会根据其动作移动,并接受正反馈奖励+1,直到杆倒下或小车超出中心位置,任务结束。

环境准备

首先,导入必要的库:

import gym
import numpy as np

env = gym.make('CartPole-v0').unwrapped

经验回放内存

DQN使用经验回放内存来学习,它存储智能体观察到的转换,并允许随机抽取样本进行训练。

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, state, action, next_state, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = (state, action, next_state, reward)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

模型定义

DQN模型使用卷积神经网络来处理当前帧和前一帧之间的差异。

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        self.head = nn.Linear(32, 2)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.head(x)
        return x

训练DQN模型

接下来,实现训练循环、智能体以及训练过程:

def train_dqn(agent, env, num_episodes=100, batch_size=128, gamma=0.99):
    memory = ReplayMemory(10000)
    optimizer = optim.Adam(agent.parameters(), lr=0.001)
    eps = 0.9
    eps_decay = 0.995
    loss_fn = nn.MSELoss()

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        rewards = 0

        while not done:
            action = agent.select_action(state, eps)  # Select action with epsilon greedy policy
            next_state, reward, done, _ = env.step(action)
            memory.push(state, action, next_state, reward)
            state = next_state
            rewards += reward

            if len(memory) > batch_size:
                batch = memory.sample(batch_size)
                states, actions, next_states, rewards = zip(*batch)
                states = torch.stack(states).to(device)
                actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1).to(device)
                next_states = torch.stack(next_states).to(device)
                rewards = torch.tensor(rewards, dtype=torch.float).unsqueeze(1).to(device)

                pred_q_values = agent(states)
                target_q_values = pred_q_values.clone()
                next_q_values = agent(next_states)
                next_q_max = next_q_values.max(dim=1, keepdim=True)[0].detach()

                target_q_values = rewards + (1 - done) * gamma * next_q_max
                target_q_values = target_q_values.to(device)

                loss = loss_fn(pred_q_values.gather(1, actions), target_q_values)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        eps = max(eps * eps_decay, 0.01)  # Decay epsilon

        print(f"Episode {episode}: Total Reward = {rewards}, Epsilon = {eps}")

    return agent

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = DQN().to(device)
agent = train_dqn(agent, env, num_episodes=100)

通过上述代码,我们实现了从环境初始化、智能体的构建、到训练过程的完整流程。在这个过程中,我们利用了经验回放内存来提高训练效率,以及逐步减少探索以优化学习策略。最终,通过绘制训练过程中的奖励,我们可以观察智能体学习的进程和效果。

结论与展望

本文详细阐述了深度Q网络(DQN)的基本原理、实现步骤和训练流程,并通过OpenAI Gym的CartPole任务集展示了DQN在实际问题上的应用。DQN不仅在处理复杂决策任务时表现出色,而且在没有显式设计策略的情况下,通过与环境的交互自动学习到有效的策略。通过引入经验回放、延迟更新目标网络等技术,DQN有效地解决了传统Q-learning算法在处理长时程依赖和数据重复性问题上的局限性。此外,本文还提供了一段完整的代码示例,包括环境设置、模型定义、训练过程和评估方法,为读者提供了从理论到实践的完整学习路径。

通过本文的学习,读者不仅能够理解DQN的核心概念和工作原理,还能够动手实现自己的DQN智能体,并在不同的任务上进行实验和优化。随着强化学习领域的不断发展,DQN算法及其变种在游戏、机器人控制、自动驾驶等领域的应用将持续扩展,成为解决复杂决策问题的强大工具。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消