
Implementing DQN by Hand in MATLAB: A Shortest-Path Problem


Link to the full code: click to open and download.

A DQN algorithm example with MATLAB code, without using the Reinforcement Learning Toolbox


This article assumes basic familiarity with DQN.

Case description:

Environment: a 30×30 grid maze with two special states, an obstacle at (15,15) and a Goal at (25,25). The objective is for the agent to reach the Goal without touching the obstacle.
Rewards: reward = -1 when the agent reaches the obstacle state; reward = 1 when the agent reaches the Goal state; reward = 0 in all other states.
States: the (x, y) coordinates of the cell the agent currently occupies.
Actions: up, down, left, right, with transition noise. For example, when the chosen action is "up", the agent actually moves up with probability 80%, left with probability 10%, and right with probability 10%.
Through repeated learning, the agent becomes able to choose the optimal path.
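To make these transition and reward rules concrete, here is a minimal MATLAB sketch of a single environment step under these settings. The function name stochastic_step, the 0-based coordinate range, the movement convention, and the choice to end an episode on hitting the obstacle are my own illustrative assumptions; in the full code the environment is provided by the SAEnvironment class instead.

function [next_state, reward, done] = stochastic_step(state, action)
% Illustrative sketch only (not the SAEnvironment implementation).
% Actions: 1=up, 2=down, 3=left, 4=right. The intended action is executed
% with probability 0.8; each of the two perpendicular actions with 0.1.
perp = [3 4; 3 4; 1 2; 1 2];      % perpendicular actions for up/down/left/right
r = rand;
if r < 0.8
    a = action;
elseif r < 0.9
    a = perp(action,1);
else
    a = perp(action,2);
end
moves = [0 1; 0 -1; -1 0; 1 0];   % assumed [dx dy] convention for up/down/left/right
next_state = min(max(state + moves(a,:), [0 0]), [29 29]); % assumed 0-based 30x30 grid
obstacle = [15 15]; goal = [25 25];
if isequal(next_state, obstacle)
    reward = -1; done = true;     % assumed: hitting the obstacle ends the episode
elseif isequal(next_state, goal)
    reward = 1;  done = true;     % reaching the Goal ends the episode
else
    reward = 0;  done = false;
end
end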

Differences between DQN and Q-learning

  1. It handles state spaces that are too large for a table.
  2. It uses a neural network to approximate the value function instead of a Q-table. Consider whether your data really requires a deep convolutional network: if a shallow network solves the problem, there is no need for a deep one. At its core this is a function-approximation problem, and this article uses the simplest option, a BP (backpropagation) network.
  3. It uses experience replay: a buffer that stores past experience so that training samples are drawn from it, breaking the correlation between consecutive transitions. There is published work on how best to sample from the buffer; interested readers can look it up, and I may share the papers I have read when I have time. A short sketch of the regression target computed from a replayed transition follows this list.
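For reference, the quantity the network is fitted to for each replayed transition (state, action, reward, next_state) is the usual one-step bootstrapped target. Below is a minimal, self-contained MATLAB sketch; Qnet is only a stand-in for the network forward pass (estimator.predict in the full code), and the transition values are made-up example numbers.

% Sketch of the DQN regression target for one replayed transition.
Qnet = @(s) rand(1,4);          % placeholder forward pass returning Q-values for 4 actions
gamma = 0.9;                    % discount factor, same value as in the main script
reward = 0; done = false;       % example transition values
next_state = [12 7];
if done
    target = reward;                                % terminal transition: no bootstrap
else
    target = reward + gamma*max(Qnet(next_state));  % bootstrap from best next-state value
end
% The network is then updated so that Q(state, action) moves toward this target,
% which is what estimator.update(state, action, target) does in the full code.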
    ------------- Without further ado, here is the code (only part of it is shown here; the full version has been uploaded as a resource). Since the code is already uploaded, you can use it as a starting point and adapt it directly into your own project.
function DQN
close all; clear; clc;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
isTraining = true; %declare if it is training
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
addpath('../Environment');
addpath('../Basic Functions');
env = SAEnvironment;

alpha = 0.1;        %learning rate settings
gamma = 0.9;        %discount factor
maxItr = 3000;      %maximum iterations for ending one episode
hidden_layer = [40 40];
estimator = DQNEstimator(env,alpha,hidden_layer);

if isTraining
    %replay buffer
    memory_size = 30000;
    memory_cnt = 0;
    batch_size = 3000;
    memory_buffer(1:memory_size) = struct('state',[],'action',[],'next_state',[],'reward',[],'done',[]);

    NUM_ITERATIONS = 20000; %change this value to set max iterations
    epsilon = 0.8;          %random action choice
    min_epsilon = 0.3;
    iterationCount(NUM_ITERATIONS) = 0;
    rwd(NUM_ITERATIONS) = 0;
else
    NUM_ITERATIONS = 5;
    epsilon = 0.3; %random action choice
    load('DQN_weights.mat','-mat');
    estimator.set_weights(Weights);
end

timeStart = clock;
for itr = 1:NUM_ITERATIONS
    env.reset([0 0]);
    if ~isTraining
        env.reset(env.locA);
        env.render(); %display the moving environment
    end
    countActions = 0; %count how many actions in one iteration
    reward = 0;
    done = false;
    state = env.current_location;

    while ~done
        if countActions == maxItr
            break;
        end
        countActions = countActions + 1;

        if ~isTraining
            values = estimator.predict(state).out_value;
            prob_a = make_epsilon_policy(values, epsilon);
            action = randsample(env.actionSpace,1,true,prob_a);
            [next_state, reward, done] = env.step(action);
            state = next_state;
            env.render(); %display the moving environment
            continue;
        end

        values = estimator.predict(state).out_value;
        prob_a = make_epsilon_policy(values, max(epsilon^log(itr),min_epsilon));
        action = randsample(env.actionSpace,1,true,prob_a);
        [next_state, reward, done] = env.step(action);

        % target = reward;
        % if ~done
        %     target = reward + gamma*max(estimator.predict(next_state).out_value);
        % end
        % estimator.update(state,action,target);

        memory_buffer(2:memory_size) = memory_buffer(1:memory_size-1);
        memory_buffer(1).state = state;
        memory_buffer(1).action = action;
        memory_buffer(1).next_state = next_state;
        memory_buffer(1).reward = reward;
        memory_buffer(1).done = done;
        memory_cnt = memory_cnt + 1;

        state = next_state;
    end

    fprintf('%d th iteration, %d actions taken, final reward is %d.\n',itr,countActions,reward);

    if isTraining
        iterationCount(itr) = countActions;
        rwd(itr) = reward;
        %memory replay
        if memory_cnt >= memory_size
            mini_batch = randsample(memory_buffer,batch_size);
            for i = 1:batch_size
                tem_state = mini_batch(i).state;
                tem_action = mini_batch(i).action;
                tem_next_state = mini_batch(i).next_state;
                tem_reward = mini_batch(i).reward;
                tem_done = mini_batch(i).done;
                tem_next_state_values = estimator.predict(tem_next_state).out_value;
                tem_target = tem_reward;
                if ~tem_done
                    tem_target = tem_reward + gamma*max(tem_next_state_values);
                end
                estimator.update(tem_state,tem_action,tem_target);
            end
        end
    end
end

if isTraining
    timeEnd = clock;
    timeDiff = sum([timeEnd - timeStart].*[0 0 0 3600 60 1]);
    simulationTime = [timeStart timeEnd timeDiff];
    save('DQN_simulationTime.mat','simulationTime');

    Weights = estimator.weights;
    save('DQN_weights.mat','Weights');
    save('DQN_iterationCount.mat','iterationCount');
    save('DQN_reward.mat','rwd');
    figure,bar(iterationCount)
    figure,bar(rwd)
end
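The script above calls make_epsilon_policy from the 'Basic Functions' folder, which is not listed here. As a rough guide, an ε-greedy action-probability helper with the same calling convention could look like the sketch below; this is an assumption about its behavior, not the exact implementation shipped with the project.

function prob_a = make_epsilon_policy(values, epsilon)
% Sketch of an epsilon-greedy policy over action values (illustrative only):
% every action receives probability epsilon/nA, and the greedy action gets
% the remaining 1-epsilon, so prob_a sums to 1 and can be fed to randsample.
nA = numel(values);
prob_a = ones(1, nA) * epsilon / nA;
[~, best] = max(values);
prob_a(best) = prob_a(best) + (1 - epsilon);
end

With this convention, epsilon = 0 is fully greedy and epsilon = 1 is uniformly random, which is consistent with how the main script decays epsilon from 0.8 toward min_epsilon during training.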
