00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef C_VFUNCTIONLEARNER__H
00033 #define C_VFUNCTIONLEARNER__H
00034
00035
00036
00037 #include "cqfunction.h"
00038 #include "cvfunction.h"
00039 #include "cqetraces.h"
00040 #include "cresiduals.h"
00041 #include "ril_debug.h"
00042
00043
00044 #include "cagentlistener.h"
00045 #include "cerrorlistener.h"
00046 #include "cparameters.h"
00047
00048 class CAgentController;
00049 class CAbstractVFunction;
00050 class CAbstractVETraces;
00051 class CActionDataSet;
00052 class CGradientVFunction;
00053 class CGradientVETraces;
00054 class CResidualFunction;
00055 class CResidualGradientFunction;
00056 class CFeatureList;
00057 class CAbstractBetaCalculator;
00058 class CFeatureVFunction;
00059
00061
00067 class CAdaptiveParameterFromValueCalculator : public CAdaptiveParameterBoundedValuesCalculator, public CSemiMDPListener
00068 {
00069 protected:
00070 CAbstractVFunction *vFunction;
00071
00072 int nSteps;
00073 int nStepsPerUpdate;
00074
00075 double value;
00076 public:
00077 CAdaptiveParameterFromValueCalculator(CParameters *targetObject, string targetParameter, CAbstractVFunction *vFunction, int stepsPerUpdate, int functionKind, double param0, double paramScale, double targetMin, double targetMax);
00078 ~CAdaptiveParameterFromValueCalculator();
00079
00080 virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00081 virtual void onParametersChanged(){CAdaptiveParameterBoundedValuesCalculator::onParametersChanged();};
00082
00083 virtual void resetCalculator();
00084 };
00085
00087
00100 class CVFunctionLearner : public CSemiMDPRewardListener, public CErrorSender
00101 {
00102 protected:
00104 CAbstractVFunction *vFunction;
00106 CAbstractVETraces *eTraces;
00107
00109 bool bExternETraces;
00110
00112 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, int duration);
00113
00114 public:
00116 CVFunctionLearner(CRewardFunction *rewardFunction, CAbstractVFunction *vFunction, CAbstractVETraces *eTraces);
00118 CVFunctionLearner(CRewardFunction *rewardFunction, CAbstractVFunction *vFunction);
00119
00120 virtual ~CVFunctionLearner();
00121
00123
00127 virtual double getTemporalDifference(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00128
00130
00133 virtual void updateVFunction(CStateCollection *oldState, CStateCollection *newState, int duration, double td);
00134
00136 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00137
00139
00144 virtual void intermediateStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00145
00147 virtual void newEpisode();
00149 CAbstractVFunction *getVFunction();
00150
00151 double getLearningRate();
00152 void setLearningRate(double learningRate);
00154 CAbstractVETraces *getVETraces();
00155 };
00156
00157 class CVFunctionGradientLearner : public CVFunctionLearner
00158 {
00159 protected:
00160 CResidualFunction *residual;
00161 CResidualGradientFunction *residualGradientFunction;
00162
00163 CGradientVFunction *gradientVFunction;
00164 CGradientVETraces *gradientETraces;
00165
00166 CFeatureList *oldGradient;
00167 CFeatureList *newGradient;
00168 CFeatureList *residualGradient;
00169
00170 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, int duration);
00171 public:
00172 CVFunctionGradientLearner(CRewardFunction *rewardFunction, CGradientVFunction *vFunction, CResidualFunction *residual, CResidualGradientFunction *residualGradientFunction);
00173
00174 ~CVFunctionGradientLearner();
00175
00176 virtual double getTemporalDifference(CStateCollection *oldState, CAction *action, double reward, CStateCollection *nextState);
00177 };
00178
00179 class CVFunctionResidualLearner : public CVFunctionGradientLearner
00180 {
00181 protected:
00182 CGradientVETraces *residualGradientTraces;
00183 CGradientVETraces *directGradientTraces;
00184
00185 CGradientVETraces *residualETraces;
00186
00187 CAbstractBetaCalculator *betaCalculator;
00188
00189 virtual void addETraces(CStateCollection *oldState, CStateCollection *newState, int duration, double td);
00190
00191 public:
00192 CVFunctionResidualLearner(CRewardFunction *rewardFunction, CGradientVFunction *vfunction, CResidualFunction *residual, CResidualGradientFunction *residualGradient, CAbstractBetaCalculator *betaCalc);
00193
00194 ~CVFunctionResidualLearner();
00195
00196 virtual void updateVFunction(CStateCollection *oldState, CStateCollection *newState, int duration, double td);
00197
00198 virtual void newEpisode();
00199
00200 CGradientVETraces *getResidualETraces() {return residualETraces;};
00201 };
00202
00203 class CVAverageTDErrorLearner : public CErrorListener, public CStateObject
00204 {
00205 protected:
00206 double updateRate;
00207
00208 CFeatureVFunction *averageErrorFunction;
00209 public:
00210 CVAverageTDErrorLearner(CFeatureVFunction *averageErrorFunction, double updateRate);
00211 virtual ~CVAverageTDErrorLearner();
00212
00213 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);
00214
00215 virtual void onParametersChanged();
00216 };
00217
00218 class CVAverageTDVarianceLearner : public CVAverageTDErrorLearner
00219 {
00220 public:
00221 CVAverageTDVarianceLearner(CFeatureVFunction *averageErrorFunction, double updateRate);
00222 virtual ~CVAverageTDVarianceLearner();
00223
00224 virtual void receiveError(double error, CStateCollection *state, CAction *action, CActionData *data = NULL);
00225 };
00226
00227
00228 #endif
00229