Deep learning PyTorch: PPO with LSTM is not working; training with episodic trajectories

I am learning reinforcement learning, and this code is giving me trouble: when the train() function runs, both the policy and the value function outputs become NaN. Below is my code for the simple CartPole environment.

These are the imports:

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
These are the hyperparameters:

learning_rate = 0.0005
gamma = 0.98
lmbda = 0.95
K = 3
epsilon = 0.1
This is the agent code for PPO with an LSTM:

class PPO_lstm(nn.Module):
  def __init__(self):
    super(PPO_lstm, self).__init__()
    self.log_list = []

    self.fc1 = nn.Linear(4, 32) # state encoder, value and policy function use it
    self.lstm = nn.LSTM(32, 64) # LSTM, value and policy function use it
    self.fc2_pi = nn.Linear(64, 2) # policy function 
    self.fc2_v = nn.Linear(64, 1) # value function
    self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

  def pi(self, x, hidden_state):  # get policy
    x = F.relu(self.fc1(x)) # x size : [32] or [batch, 32]
    x = x.view(1, -1, 32) # x size : [1, batch, 32]
    x, new_hidden = self.lstm(x, hidden_state) # x size : [1, batch, 64]
    x = x.squeeze(0) # x size : [batch, 64]
    pi = F.log_softmax(self.fc2_pi(x), dim=1) # log-probabilities over the 2 actions
    return pi, new_hidden # pi size : [batch, 2], new_hidden : (h, c), each [1, batch, 64]

  def v(self, x, hidden_state):  # get state value
    x = F.relu(self.fc1(x)) # x size : [32] or [batch, 32]
    x = x.view(1, -1, 32) # x size : [1, batch, 32]
    x, new_hidden = self.lstm(x, hidden_state) # x size : [1, batch, 64]
    x = x.squeeze(0) # x size : [batch, 64]
    v = self.fc2_v(x) # v size : [batch, 1]
    return v # v size : [batch, 1]

  def add_log(self, log):  # store one transition for the train() function
    self.log_list.append(log) # s, a, r, s', action_prob, done, hidden(h, c), hidden_prime(h, c)

  def make_batch(self): # make batch file using log_list for training
    s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst, h_h_lst, h_c_lst, h_prime_h_lst, h_prime_c_lst = [], [], [], [], [], [], [], [], [], []
    for i in self.log_list:
      s, a, r, s_prime, done, action_prob, h, h_prime = i

      s_lst.append(s) # s_lst size : [batch, 4]
      a_lst.append([a]) # a_lst size : [batch, 1]
      r_lst.append([r]) # r_lst size : [batch, 1]
      s_prime_lst.append(s_prime) # s_prime_lst size : [batch, 4]
      prob_a_lst.append([action_prob]) # prob_a_lst size : [batch, 1]
      done_lst.append([0 if done else 1]) # done mask (0 if episode ended, else 1), size : [batch, 1]
      h_h, h_c = h # hidden_h, hidden_c size : [1, 1, 64]
      h_h_lst.append(h_h.detach().numpy()) # h_h_lst size : [batch], each element : [1, 1, 64]
      h_c_lst.append(h_c.detach().numpy()) # h_c_lst size : [batch], each element : [1, 1, 64]
      h_prime_h, h_prime_c = h_prime
      h_prime_h_lst.append(h_prime_h.detach().numpy()) # h_prime_h_lst size : [batch], each element : [1, 1, 64]
      h_prime_c_lst.append(h_prime_c.detach().numpy()) # h_prime_c_lst size : [batch], each element : [1, 1, 64]

    s_lst = torch.tensor(s_lst, dtype=torch.float)    
    a_lst = torch.tensor(a_lst)    
    r_lst = torch.tensor(r_lst)    
    s_prime_lst = torch.tensor(s_prime_lst, dtype=torch.float)    
    prob_a_lst = torch.tensor(prob_a_lst, dtype=torch.float)    
    done_lst = torch.tensor(done_lst)
    h_h_lst = torch.tensor(h_h_lst, dtype=torch.float).squeeze().unsqueeze(0)
    h_c_lst = torch.tensor(h_c_lst, dtype=torch.float).squeeze().unsqueeze(0)
    h_prime_h_lst = torch.tensor(h_prime_h_lst, dtype=torch.float).squeeze().unsqueeze(0)
    h_prime_c_lst = torch.tensor(h_prime_c_lst, dtype=torch.float).squeeze().unsqueeze(0) # size : [1, batch, 64]

    self.log_list = []
    return s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst, h_h_lst, h_c_lst, h_prime_h_lst, h_prime_c_lst

This is the training function, which uses the batch of (s, a, r, s', done, hidden, hidden_prime).
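
(The original listing for train() is not reproduced above. As a stand-in, here is a minimal sketch of what a clipped-PPO update with GAE over this batch could look like, built around make_batch() and the hyperparameters gamma, lmbda, K, and epsilon. It assumes prob_a holds the old log-probabilities, since pi() returns log_softmax output; this is an assumed reconstruction, not necessarily the code that produced the NaN.)

  def train(self):  # sketch: clipped PPO with GAE, reconstructed around make_batch()
    s, a, r, s_prime, prob_a, done_mask, h_h, h_c, h_prime_h, h_prime_c = self.make_batch()
    first_hidden = (h_h, h_c)               # per-step hidden states at time t,   each [1, batch, 64]
    second_hidden = (h_prime_h, h_prime_c)  # per-step hidden states at time t+1, each [1, batch, 64]

    for _ in range(K):
      # TD target from the next states; done_mask zeroes out the bootstrap at episode ends
      v_s = self.v(s, first_hidden)                                        # [batch, 1]
      td_target = r + gamma * self.v(s_prime, second_hidden) * done_mask   # [batch, 1]
      delta = (td_target - v_s).detach().numpy()

      # GAE computed backwards over the stored trajectory
      advantage_lst = []
      advantage = 0.0
      for delta_t in delta[::-1]:
        advantage = gamma * lmbda * advantage + delta_t[0]
        advantage_lst.append([advantage])
      advantage_lst.reverse()
      advantage = torch.tensor(advantage_lst, dtype=torch.float)           # [batch, 1]

      # pi() returns log-probabilities, so the importance ratio is exp(new_log_prob - old_log_prob),
      # assuming prob_a was stored from the same log_softmax output
      log_pi, _ = self.pi(s, first_hidden)
      log_pi_a = log_pi.gather(1, a)
      ratio = torch.exp(log_pi_a - prob_a)

      surr1 = ratio * advantage
      surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
      loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(v_s, td_target.detach())

      self.optimizer.zero_grad()
      loss.mean().backward()
      self.optimizer.step()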

This is the main function, which runs the OpenAI Gym CartPole environment:

def main():
  env = gym.make('CartPole-v1')
  agent = PPO_lstm()
  score = 0.0
  print_interval = 20

  for episode in range(10000):
    hidden = (torch.zeros([1, 1, 64], dtype=torch.float), torch.zeros([1, 1, 64], dtype=torch.float)) # lstm init hidden state
    s = env.reset()
    done = False
    while not done:
      for step in range(20):
        pi, hidden_prime = agent.pi(torch.from_numpy(s).float(), hidden) # pi size : [1, 2]
        pi = pi.squeeze() # pi size : [2]
        m = Categorical(pi)
        action = m.sample().item()
        s_prime, reward, done, info = env.step(action)
        agent.add_log((s, action, reward / 100.0, s_prime, pi[action].item(), done, hidden, hidden_prime))
        # hidden dim : ([1, 1, 64], [1, 1, 64])
        hidden = hidden_prime
        s = s_prime
        score += reward
        if done:
          break
      agent.train()
    if episode % print_interval == 0 and episode != 0:
      print("episode {}'s avg score : {}".format(episode, score/print_interval))
      score = 0.0

  env.close()

if __name__ == '__main__':
  main()

When I tried this code without the LSTM (using an MLP policy and value function), it ran fine.
Could you tell me what goes wrong when the LSTM / hidden state is used?
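
One detail I am not sure about (it may or may not be related to the NaN): pi() returns log-probabilities via F.log_softmax, but Categorical(pi) in main() treats its first positional argument as probabilities. Categorical can take log-space input through its logits keyword instead, for example (a possible adjustment to double-check, not a confirmed fix):

        m = Categorical(logits=pi)  # pi is the log_softmax output
        action = m.sample().item()
        old_log_prob = m.log_prob(torch.tensor(action)).item()  # old log-prob for the PPO ratio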

Did you ever get this working?