00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef C_LSTD_H
00033 #define C_LSTD_H
00034
00035 #include "clearndataobject.h"
00036 #include "cagentlistener.h"
00037 #include "csupervisedlearner.h"
00038
00039 #include <stdlib.h>
00040 #include <stdio.h>
00041
00042
00043
00044 class Matrix;
00045 class ColumnVector;
00046
00047 class CFeatureVFunction;
00048 class CFeatureQFunction;
00049 class CStateProperties;
00050 class CFeatureVETraces;
00051 class CFeatureQETraces;
00052 class CFeatureList;
00053
00054 class CGradientQETraces;
00055 class CAgentController;
00056 class CActionDataSet;
00057
00058
/// Abstract base for least-squares temporal-difference learners.
///
/// The class is abstract: subclasses must implement the protected pure-virtual
/// hooks that supply feature gradients and trace handling. This header only
/// declares the public step/episode callbacks (nextStep, newEpisode) and the
/// data persistence API (resetData, loadData, saveData). Their definitions
/// live elsewhere.
///
/// NOTE(review): the class name suggests LSTD(lambda) with eligibility traces
/// ("ETraces"). Confirm against the implementation file.
/// NOTE(review): the class holds raw CFeatureList pointers and declares a
/// destructor, but no copy constructor or copy assignment. If the destructor
/// deletes these lists, copying an instance (or a subclass instance) would
/// double-delete. Consider declaring the copy operations private and
/// unimplemented.
00059 class CLSTDLambda : public CSemiMDPRewardListener, public CLearnDataObject, public CLeastSquaresLearner
00060 {
00061 protected:
00062
00063
00064
/// Feature-gradient buffers for the transition's old and new state.
/// Presumably passed as the `gradient` out-parameter of getOldGradient /
/// getNewGradient. TODO confirm; allocation/ownership is not visible here.
00065 CFeatureList *oldStateGradient;
00066 CFeatureList *newStateGradient;
00067
00068
/// Presumably an episode counter compared against nUpdateEpisode. TODO confirm.
00069 int nEpisode;
00070
00071
/// Hook: presumably fills `gradient` with the feature gradient of
/// (`stateCol`, `action`) for the transition's old state.
00072 virtual void getOldGradient(CStateCollection *stateCol, CAction *action, CFeatureList *gradient) = 0;
/// Hook: presumably fills `gradient` for the transition's new state. No
/// action is passed, so any next action must be chosen by the subclass.
00073 virtual void getNewGradient(CStateCollection *stateCol, CFeatureList *gradient) = 0;
00074
/// Hook: update the subclass's trace object for (`stateCol`, `action`).
00075 virtual void updateETraces(CStateCollection *stateCol, CAction *action) = 0;
/// Hook: return the traces as a feature list. Ownership of the returned
/// pointer is not visible here (presumably owned by the subclass).
00076 virtual CFeatureList *getGradientETraces() = 0;
/// Hook: reset the traces (presumably at episode boundaries). TODO confirm.
00077 virtual void resetETraces() = 0;
00078
00079 public:
/// NOTE(review): public mutable data member. Presumably initialised from the
/// constructor's nUpdatePerEpisode argument (note the differing names).
/// Confirm in the .cpp.
00080 int nUpdateEpisode;
00081
/// @param rewardFunction  reward function; presumably forwarded to the
///                        CSemiMDPRewardListener base. TODO confirm.
/// @param updateFunction  gradient-update target; presumably the function
///                        whose parameters the least-squares solution sets.
/// @param nUpdatePerEpisode  update frequency tied to episodes. Exact
///                        semantics (updates per episode vs. every N
///                        episodes) are not visible here; confirm.
00082 CLSTDLambda(CRewardFunction *rewardFunction, CGradientUpdateFunction *updateFunction, int nUpdatePerEpisode);
/// Virtual because this is a polymorphic (abstract) base class.
00083 virtual ~CLSTDLambda();
00084
/// Per-transition callback: old state, action taken, observed reward and new
/// state. Presumably overrides the CSemiMDPRewardListener hook.
00085 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);
/// Episode-boundary callback.
00086 virtual void newEpisode();
00087
00088
/// Data persistence (presumably the CLearnDataObject interface). The FILE*
/// stream is presumably opened and closed by the caller.
00089 virtual void resetData();
00090 virtual void loadData(FILE *stream);
00091 virtual void saveData(FILE *stream);
00092 };
00093
/// Concrete CLSTDLambda that learns a feature-based V-function
/// (CFeatureVFunction) using CFeatureVETraces for its traces.
///
/// Implements all five protected hooks declared by CLSTDLambda.
00094 class CVLSTDLambda : public CLSTDLambda
00095 {
00096 protected:
/// Value function being learned, and its trace object. Ownership is not
/// visible here; presumably vFunction is caller-owned and vETraces is
/// created and destroyed by this class. TODO confirm.
00097 CFeatureVFunction *vFunction;
00098 CFeatureVETraces *vETraces;
00099
/// The `action` parameter is required by the base-class signature. Whether a
/// V-function implementation uses it is not visible here.
00100 virtual void getOldGradient(CStateCollection *stateCol, CAction *action, CFeatureList *gradient);
00101 virtual void getNewGradient(CStateCollection *stateCol, CFeatureList *gradient);
00102
00103 virtual void updateETraces(CStateCollection *stateCol, CAction *action);
00104 virtual CFeatureList *getGradientETraces();
00105 virtual void resetETraces();
00106 public:
/// @param rewardFunction  see CLSTDLambda.
/// @param updateFunction  the V-function to learn. It is passed where the base
///                        expects a CGradientUpdateFunction*, so
///                        CFeatureVFunction presumably derives from it.
/// @param nUpdatePerEpisode  see CLSTDLambda.
00107 CVLSTDLambda(CRewardFunction *rewardFunction, CFeatureVFunction *updateFunction, int nUpdatePerEpisode);
00108 virtual ~CVLSTDLambda();
00109 };
00110
/// Concrete CLSTDLambda that learns a feature-based Q-function
/// (CFeatureQFunction) using CGradientQETraces for its traces.
///
/// Implements all five protected hooks declared by CLSTDLambda.
00111 class CQLSTDLambda : public CLSTDLambda
00112 {
00113 protected:
/// Action-value function being learned, and its trace object. Ownership is
/// not visible here. TODO confirm.
00114 CFeatureQFunction *qFunction;
00115 CGradientQETraces *qETraces;
00116
/// getNewGradient() receives no action, so the new state's action must come
/// from elsewhere. Presumably `policy` chooses it and `actionDataSet` holds
/// the chosen action's data. TODO confirm in the .cpp.
00117 CAgentController *policy;
00118 CActionDataSet *actionDataSet;
00119
00120 virtual void getOldGradient(CStateCollection *stateCol, CAction *action, CFeatureList *gradient);
00121 virtual void getNewGradient(CStateCollection *stateCol, CFeatureList *gradient);
00122
00123 virtual void updateETraces(CStateCollection *stateCol, CAction *action);
00124 virtual CFeatureList * getGradientETraces();
00125 virtual void resetETraces();
00126 public:
/// @param rewardFunction  see CLSTDLambda.
/// @param updateFunction  the Q-function to learn (passed where the base
///                        expects a CGradientUpdateFunction*).
/// @param policy          controller stored in `policy`; presumably not owned.
/// @param nUpdatePerEpisode  see CLSTDLambda.
00127 CQLSTDLambda(CRewardFunction *rewardFunction, CFeatureQFunction *updateFunction, CAgentController *policy, int nUpdatePerEpisode);
00128 virtual ~CQLSTDLambda();
00129 };
00130
00131
00132 #endif
00133
00134