00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CABSTRACTRILEARNER_H
00033 #define CABSTRACTRILEARNER_H
00034
00035 #include "cerrorlistener.h"
00036 #include "cagentlistener.h"
00037 #include "cbaseobjects.h"
00038
00039 class CAgentController;
00040 class CDeterministicController;
00041 class CAbstractQFunction;
00042 class CAbstractQETraces;
00043 class CActionDataSet;
00044 class CGradientQFunction;
00045 class CGradientQETraces;
00046 class CResidualFunction;
00047 class CResidualGradientFunction;
00048 class CFeatureList;
00049 class CAbstractBetaCalculator;
00050 class CFeatureQFunction;
00051
00053
/// Base class of the temporal-difference learners declared in this header.
///
/// It is a CSemiMDPRewardListener (it receives (oldState, action, reward,
/// nextState) steps through nextStep()/intermediateStep()) and a CErrorSender.
/// It keeps pointers to a Q-function, an e-trace object, an estimation policy
/// and an action data set; the update itself is split into the protected
/// virtual hooks learnStep(), getTemporalDifference(), getResidual() and
/// addETraces() so that subclasses can replace individual parts.
///
/// NOTE(review): ownership of the pointed-to objects is not visible in this
/// header; check the constructor/destructor definitions before deleting or
/// sharing any of them.
00085 class CTDLearner : public CSemiMDPRewardListener, public CErrorSender
00086 {
00087 protected:
00088
/// Presumably true when the e-traces were passed in by the caller (4-argument
/// constructor) instead of being created by the learner - TODO confirm in the
/// implementation file.
00090 bool externETraces;
00091
/// Policy stored by the constructors / setEstimationPolicy() and returned by
/// getEstimationPolicy(). Presumably used to pick the action whose Q-value is
/// used for the next state - TODO confirm.
00093 CAgentController *estimationPolicy;
00094
/// Presumably the action most recently chosen by the estimation policy -
/// TODO confirm against the implementation.
00096 CAction *lastEstimatedAction;
00097
/// Q-function handed to the constructors; returned by getQFunction().
00098 CAbstractQFunction *qfunction;
00099
/// E-trace object; returned by getETraces().
00100 CAbstractQETraces *etraces;
00101
/// Action data set used by the learner; how it is filled is not visible here.
00102 CActionDataSet *actionDataSet;
00103
00105
/// Hook performing one learning update for the given transition.
/// Virtual so subclasses can replace the whole update (CTDResidualLearner does).
00112 virtual void learnStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00113
/// Hook returning the temporal-difference value for the given transition.
00115 virtual double getTemporalDifference(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00116
/// Hook returning the residual computed from the old Q-value, the reward, the
/// step duration and the new Q-value. The exact formula lives in the
/// implementation file (presumably the usual discounted TD error - TODO confirm).
00118 virtual double getResidual(double oldQ, double reward, int duration, double newQ);
00119
/// Hook that updates the e-traces for the given transition.
00121 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action);
00122
00123 public:
/// Constructs a learner from a reward function, a Q-function, a
/// caller-supplied e-trace object and an estimation policy.
00125 CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CAbstractQETraces *etraces, CAgentController *estimationPolicy);
00127
/// Constructs a learner without a caller-supplied e-trace object.
/// Presumably the e-traces are then obtained from the Q-function - TODO confirm.
00130 CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CAgentController *estimationPolicy);
00131
/// Virtual destructor, so learners can be destroyed through a base pointer.
00132 virtual ~CTDLearner();
00133
/// Load / save the learner's values using a file name.
00134 virtual void loadValues(char *filename);
00135 virtual void saveValues(char *filename);
00136
/// Load / save the learner's values using an already opened C stream.
00137 virtual void loadValues(FILE *stream);
00138 virtual void saveValues(FILE *stream);
00139
/// Listener callback for a completed step (oldState, action, reward, nextState).
00141 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00143
/// Listener callback for an intermediate step; takes the same arguments as
/// nextStep(). How it differs from nextStep() is defined in the implementation.
00147 virtual void intermediateStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00148
/// Listener callback invoked at the start of a new episode.
00150 virtual void newEpisode();
00151
00153
/// Sets the parameter named alpha (presumably the learning rate - TODO confirm).
00155 void setAlpha(double alpha);
/// Sets the parameter named lambda (presumably the e-trace decay - TODO confirm).
00157 void setLambda(double lambda);
00158
/// Accessors for the estimation policy.
00159 CAgentController* getEstimationPolicy();
00160 void setEstimationPolicy(CAgentController * estimationPolicy);
00161
/// Returns the Q-function used by this learner.
00162 CAbstractQFunction* getQFunction();
00163
/// Returns the e-trace object used by this learner.
00164 CAbstractQETraces *getETraces();
00165 };
00166
00168
/// TD learner constructed from only a reward function and a Q-function.
///
/// Unlike CTDLearner, the constructor takes no estimation policy; presumably
/// the class supplies a greedy policy over the Q-function itself (Q-learning)
/// - TODO confirm in the implementation file.
00174 class CQLearner : public CTDLearner
00175 {
00176 public:
/// Builds the learner for the given reward function and Q-function.
00177 CQLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction);
/// Destructor; virtual because the base-class destructor is virtual.
00178 ~CQLearner();
00179 };
00180
00182
/// TD learner constructed from a reward function, a Q-function and a
/// deterministic controller.
///
/// Presumably the controller passed as `agent` serves as the estimation
/// policy, which would make this an on-policy (SARSA) learner - TODO confirm
/// in the implementation file.
00191 class CSarsaLearner : public CTDLearner
00192 {
00193 public:
/// Builds the learner for the given reward function, Q-function and controller.
00194 CSarsaLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CDeterministicController *agent);
/// Destructor; virtual because the base-class destructor is virtual.
00195 ~CSarsaLearner();
00196 };
00197
00198
/// TD learner working on a gradient-based Q-function.
///
/// It re-declares the CTDLearner hooks getResidual() and addETraces() and
/// holds a residual function, a residual-gradient function and three feature
/// lists in addition to the base-class state.
00199 class CTDGradientLearner : public CTDLearner
00200 {
00201 protected:
/// Residual function and residual-gradient function given to the constructor.
00202 CResidualFunction *residual;
00203 CResidualGradientFunction *residualGradient;
/// Gradient-typed Q-function and e-traces; presumably typed views of the
/// base-class qfunction / etraces members - TODO confirm.
00204 CGradientQFunction *gradientQFunction;
00205 CGradientQETraces *gradientQETraces;
00206
/// Feature lists; presumably scratch buffers for the gradients of the old and
/// new state and for the resulting residual gradient - TODO confirm.
00207 CFeatureList *oldGradient;
00208 CFeatureList *newGradient;
00209 CFeatureList *residualGradientFeatures;
00210
/// Overrides of the CTDLearner hooks (same signatures as in the base class).
00211 virtual double getResidual(double oldQ, double reward, int duration, double newQ);
00212 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action);
00213
00214 public:
/// Builds the learner from a reward function, a gradient Q-function, a
/// controller (`agent`), a residual function and a residual-gradient function.
00215 CTDGradientLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agent, CResidualFunction *residual, CResidualGradientFunction *residualGradient);
00216
/// Destructor; virtual because the base-class destructor is virtual.
00217 ~CTDGradientLearner();
00218 };
00219
/// Gradient TD learner that additionally keeps several gradient e-trace
/// objects and a beta calculator, and replaces learnStep().
///
/// Presumably beta weights a direct-gradient against a residual-gradient
/// update - TODO confirm in the implementation file.
00220 class CTDResidualLearner : public CTDGradientLearner
00221 {
00222 protected:
00223
/// E-traces named after the residual-gradient and the direct-gradient update.
00224 CGradientQETraces *residualGradientTraces;
00225 CGradientQETraces *directGradientTraces;
00226
/// E-traces returned by getResidualETraces().
00227 CGradientQETraces *residualETraces;
00228
/// Beta calculator passed to the constructor as `betaCalc`.
00229 CAbstractBetaCalculator *betaCalculator;
00230
/// Replaces the CTDLearner learning-step hook.
00231 virtual void learnStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00232
00233 public:
/// Builds the learner; takes the same arguments as CTDGradientLearner plus
/// the beta calculator.
00234 CTDResidualLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agent, CResidualFunction *residual, CResidualGradientFunction *residualGradient, CAbstractBetaCalculator *betaCalc);
00235
/// Destructor; virtual because the base-class destructor is virtual.
00236 ~CTDResidualLearner();
00237
/// Overrides CTDLearner::newEpisode() (implicitly virtual).
00238 void newEpisode();
00239
/// E-trace update that also receives the value `td`.
/// NOTE(review): this is a new overload, not an override - it has a fourth
/// parameter and therefore hides the inherited 3-argument addETraces() inside
/// this class's scope. It is also public, while the inherited one is protected.
00240 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action, double td);
00241
/// Returns the residualETraces member.
/// NOTE(review): the ';' after the function body is redundant.
00242 CGradientQETraces *getResidualETraces() {return residualETraces;};
00243 };
00244
00245
00246
/// Error listener that stores an update rate and a feature Q-function named
/// averageErrorFunction.
///
/// Presumably each received error is blended into that function with weight
/// updateRate, giving a running average of the TD error per state/action -
/// TODO confirm in the implementation file.
00247 class CQAverageTDErrorLearner : public CErrorListener, public CStateObject
00248 {
00249 protected:
/// Update rate given to the constructor.
00250 double updateRate;
00251
/// Feature Q-function given to the constructor.
00252 CFeatureQFunction *averageErrorFunction;
00253 public:
/// Builds the listener from the target function and the update rate.
00254 CQAverageTDErrorLearner(CFeatureQFunction *averageErrorFunction, double updateRate);
/// Virtual destructor.
00255 virtual ~CQAverageTDErrorLearner();
00256
/// Callback invoked when parameters change; presumably refreshes updateRate
/// from the parameter set - TODO confirm.
00257 virtual void onParametersChanged();
00258
/// Receives an error value for the given state and action; `data` is optional
/// and defaults to NULL.
/// NOTE(review): default arguments on virtual functions are chosen from the
/// static type of the call expression, not from the dynamic type.
00259 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);
00260 };
00261
/// Variant of CQAverageTDErrorLearner that re-declares receiveError().
///
/// Presumably it accumulates the squared error (a variance estimate) instead
/// of the error itself - TODO confirm in the implementation file.
00262 class CQAverageTDVarianceLearner : public CQAverageTDErrorLearner
00263 {
00264 public:
00265
/// Builds the listener; takes the same arguments as the base-class constructor.
00266 CQAverageTDVarianceLearner(CFeatureQFunction *averageErrorFunction, double updateRate);
/// Virtual destructor.
00267 virtual ~CQAverageTDVarianceLearner();
00268
/// Overrides CQAverageTDErrorLearner::receiveError() with the same signature
/// and the same NULL default for `data`.
00269 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);
00270 };
00271
00272 #endif
00273