12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
24 using namespace mlpack::ann;
29 template <
typename NetworkType = FFN<MeanSquaredError<>,
30 GaussianInitialization>>
53 const bool isNoisy =
false):
57 network.Add(
new Linear<>(inputDim, h1));
61 noisyLayerIndex.push_back(network.Model().size());
64 noisyLayerIndex.push_back(network.Model().size());
71 network.Add(
new Linear<>(h2, outputDim));
75 SimpleDQN(NetworkType network,
const bool isNoisy =
false):
76 network(std::move(network)),
91 void Predict(
const arma::mat state, arma::mat& actionValue)
93 network.Predict(state, actionValue);
102 void Forward(
const arma::mat state, arma::mat& target)
104 network.Forward(state, target);
112 network.ResetParameters();
120 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
122 boost::get<NoisyLinear<>*>
123 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
//! Return the parameters of the wrapped network (read-only access).
//! Simply forwards to network.Parameters(); nothing is modified here.
128 const arma::mat&
Parameters()
const {
return network.Parameters(); }
139 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
141 network.Backward(state, target, gradient);
152 std::vector<size_t> noisyLayerIndex;
void ResetParameters()
Resets the parameters of the network.
void ResetNoise()
Resets the noise of the network, if the network is noisy.
SimpleDQN()
Default constructor.
arma::mat & Parameters()
Modify the Parameters.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Implementation of the base layer.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the NoisyLinear layer class.
const arma::mat & Parameters() const
Return the Parameters.
SimpleDQN(NetworkType network, const bool isNoisy=false)
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false)
Construct an instance of the SimpleDQN class.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
The mean squared error performance function measures the network's performance according to the mean of squared errors.
This class is used to initialize the weight matrix with a Gaussian distribution.