为了账号安全,请及时绑定邮箱和手机立即绑定

使用 TorchRL 进行强化学习 (PPO) 的教程 — PyTorch 教程 2.4.0+cu124 文档

标签:
杂七杂八

引言和准备工作

为了使用 TorchRL 进行强化学习,尤其是 PPO,您需要确保已安装必要的依赖库,并在 Google Colab 环境中配置好正确的设置。首先执行以下命令安装所需库:

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

在 Google Colab 中运行代码时,请确保环境已被正确配置。

超参数定义

定义一组用于训练的超参数,这将影响数据收集和策略优化的过程。以下是一些关键参数:

import torch

frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
环境和转换器创建

使用 TorchRL 和 Gym 创建并配置环境。创造归一化、双精度转换和步数计数转换的环境:

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv, GymEnv
from torchrl.envs.utils import check_env_specs, set_exploration_type

env = GymEnv('CartPole-v1', device='cpu')
env = TransformedEnv(env, [DoubleToFloat(), StepCounter(), ObservationNorm()])
策略与值模型设计

设计策略网络和值模型,并利用 TorchRL 的类进行配置。这包括定义 Actor 和 Critic 模型:

class Actor(nn.Module):
    def __init__(self, num_cells):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, num_cells), nn.Tanh(),
            nn.Linear(num_cells, num_cells), nn.Tanh(),
            nn.Linear(num_cells, num_cells), nn.Tanh(),
            nn.Linear(num_cells, 2 * 1)
        )
        self.actor = ProbabilisticActor(
            module=self.net,
            spec=GymEnv('CartPole-v1').action_spec,
            in_keys=["loc", "scale"],
            distribution_class=TanhNormal,
            distribution_kwargs={"min": GymEnv('CartPole-v1').action_spec.space.low,
                                "max": GymEnv('CartPole-v1').action_spec.space.high},
            return_log_prob=True
        )

class Critic(nn.Module):
    def __init__(self, num_cells):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, num_cells), nn.Tanh(),
            nn.Linear(num_cells, num_cells), nn.Tanh(),
            nn.Linear(num_cells, num_cells), nn.Tanh(),
            nn.Linear(num_cells, 1)
        )
        self.critic = ValueOperator(
            module=self.net,
            in_keys=["observation"]
        )

actor = Actor(64)
critic = Critic(64)
数据收集器与重放缓冲区实现

创建数据收集器和重放缓冲区,以便在训练过程中使用,确保数据高效收集与存储:

replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=frames_per_batch),
                             sampler=SamplerWithoutReplacement())
PPO 损失函数与训练循环

使用 ClipPPOLoss 实现 PPO,并创建 GAE 模块、定义训练循环流程。这包括优化器设置和学习率调整策略:

advantage_module = GAE(gamma=gamma, lmbda=lmbda, value_network=critic, average_gae=True)
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps
)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_frames // frames_per_batch, 0.0)
训练循环

接下来将实现完整的训练循环,包括计算优势、从重放缓冲区采样、计算并优化损失,以及更新日志和进度条:

from tqdm import tqdm

logs = defaultdict(list)
pbar = tqdm(total=total_frames)

for _ in range(total_frames // frames_per_batch):
    for _ in range(num_epochs):
        # 计算优势
        advantage_module(tensordict_data)
        # 从重放缓冲区采样
        subdata = replay_buffer.sample(sub_batch_size)
        # 计算并优化损失
        loss_vals = loss_module(subdata.to(device))
        loss_value = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
    # 更新日志和进度条
    # ...
总结与后续步骤

总结本教程的要点,并鼓励读者进一步学习和实验:

此教程提供了一个基础的框架,用于使用 TorchRL 进行 PPO 强化学习。实际应用中可能需要根据特定任务进行调整。借助 TorchRL 的特性和功能,您可以探索多种策略和环境,以解决各种强化学习问题。推荐进一步阅读 TorchRL 的官方文档和社区资源,以深入了解所有功能和最佳实践。在实践中,不断调整超参数、测试不同策略和环境的组合,可以优化模型的性能并解决更复杂的问题。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消