00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CTDLEARNERPOLICIES_H
00033 #define CTDLEARNERPOLICIES_H
00034
00035 #include "cagentcontroller.h"
00036 #include "cparameters.h"
00037
// Forward declarations. Every one of these types that the visible part of this
// header uses is used only through a pointer, so a declaration is enough and
// the defining headers do not have to be included here.
// NOTE(review): CAbstractFeatureStochasticEstimatedModel is not referenced
// anywhere in the visible part of this header, and
// CQFunctionFromTransitionFunction is forward-declared a second time further
// down (before CVMStochasticPolicy). Both are harmless; confirm whether they
// are still needed.
00038 class CAbstractFeatureStochasticEstimatedModel;
00039 class CTransitionFunction;
00040 class CAbstractQFunction;
00041 class CActionSet;
00042 class CFeatureList;
00043 class CActionStatistics;
00044 class CAbstractVFunction;
00045 class CQFunctionFromTransitionFunction;
00046 class CStateCollectionImpl;
00047
00048 #include "newmat/newmat.h"
00049
00051
/// Agent controller that is constructed from an action set and a Q-function
/// and keeps a pointer to each.
/// NOTE(review): judging by the class name, getNextAction() presumably returns
/// the action with the best Q-value in the given state; the implementation is
/// not part of this header -- confirm against the .cpp file.
00054 class CQGreedyPolicy : public CAgentController
00055 {
00056 protected:
/// Q-function used by the policy (presumably the one passed to the constructor).
00057 CAbstractQFunction *qFunction;
/// Action set used by the policy (presumably the `actions` constructor argument).
00058 CActionSet *availableActions;
00059 public:
/// @param actions    set of actions the policy may choose from
/// @param qFunction  Q-function used to rate the actions
00060 CQGreedyPolicy(CActionSet *actions, CAbstractQFunction *qFunction);
// NOTE(review): not declared virtual here; it is virtual only if the
// destructor of CAgentController is virtual -- confirm, since the class is
// used polymorphically. Ownership of the two pointers is not visible here.
00061 ~CQGreedyPolicy();
00062
/// Chooses the next action for @p state.
/// @param state  state collection the decision is based on
/// @param data   optional action data set; defaults to NULL
/// @return pointer to the chosen action
00064 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *data = NULL);
00065
00066 };
00067
00069
/// Abstract interface for an action distribution: getDistribution() is pure
/// virtual and has to be provided by every concrete distribution.
/// The class derives virtually from CParameterObject.
/// NOTE(review): no destructor is declared here; deleting a distribution
/// through a CActionDistribution pointer is only safe if CParameterObject has
/// a virtual destructor -- confirm.
00073 class CActionDistribution : virtual public CParameterObject
00074 {
00075 public:
00077
/// Computes the distribution over @p availableActions for @p state.
/// @param state             state collection the distribution is computed for
/// @param availableActions  actions the distribution ranges over
/// @param actionFactors     array of doubles exchanged with the caller;
///                          presumably one entry per available action that
///                          holds the action values on entry and the resulting
///                          probabilities on return -- TODO confirm in the
///                          implementations
00080 virtual void getDistribution(CStateCollection *state, CActionSet *availableActions, double *actionFactors) = 0;
/// Tells whether getGradientFactors() is supported. The default is false;
/// differentiable distributions override this to return true.
00081 virtual bool isDifferentiable() {return false;};
00082
00084
/// Gradient counterpart of getDistribution(); defined outside this header.
/// @param state            state collection the gradient is computed for
/// @param usedAction       the action the gradient refers to
/// @param actions          actions the distribution ranges over
/// @param actionFactors    array of doubles, presumably the action values --
///                         TODO confirm
/// @param gradientFactors  newmat ColumnVector, presumably receiving the result
///                         -- TODO confirm
00085 virtual void getGradientFactors(CStateCollection *state, CAction *usedAction, CActionSet *actions, double *actionFactors, ColumnVector *gradientFactors);
00086 };
00087
00089
/// Concrete action distribution constructed from a single parameter `beta`.
/// It overrides getDistribution() and getGradientFactors() and reports itself
/// as differentiable.
/// NOTE(review): by its name this is a soft-max (Boltzmann) distribution with
/// `beta` as the scaling factor of the action values; the formula itself is in
/// the .cpp file -- confirm there.
00096 class CSoftMaxDistribution : public CActionDistribution
00097 {
// The protected section is empty; no data member is declared in this class, so
// `beta` is presumably stored through the CParameterObject base -- confirm.
00098 protected:
00099 public:
00100
/// @param beta  parameter of the distribution
// NOTE(review): single-argument constructor without `explicit`, so a double
// converts implicitly to a CSoftMaxDistribution.
00101 CSoftMaxDistribution(double beta);
00102
/// Implementation of CActionDistribution::getDistribution(); note that the
/// array parameter is called `values` here and `actionFactors` in the base.
00103 virtual void getDistribution(CStateCollection *state, CActionSet *availableActions, double *values);
00104
/// Always true: this distribution provides getGradientFactors().
00105 virtual bool isDifferentiable() {return true;};
00106
/// Overrides CActionDistribution::getGradientFactors(); defined outside this header.
00107 virtual void getGradientFactors(CStateCollection *state, CAction *usedAction, CActionSet *actions, double *actionFactors, ColumnVector *gradientFactors);
00108
00109 };
00110
/// Concrete action distribution constructed from a single parameter
/// `maxAbsValue`. It overrides getDistribution() only and reports itself as
/// not differentiable; getGradientFactors() is inherited unchanged.
/// NOTE(review): presumably a soft-max variant that rescales the action values
/// with respect to `maxAbsValue` -- the computation is not visible here, confirm
/// in the .cpp file.
00111 class CAbsoluteSoftMaxDistribution : public CActionDistribution
00112 {
// The protected section is empty; no data member is declared in this class.
00113 protected:
00114 public:
00115
/// @param maxAbsValue  parameter of the distribution
// NOTE(review): single-argument constructor without `explicit`.
00116 CAbsoluteSoftMaxDistribution(double maxAbsValue);
00117
/// Implementation of CActionDistribution::getDistribution().
00118 virtual void getDistribution(CStateCollection *state, CActionSet *availableActions, double *values);
00119
/// Always false, i.e. the same answer as the base-class default.
00120 virtual bool isDifferentiable() {return false;};
00121
00122
00123 };
00124
00126
/// Concrete action distribution without parameters: no constructor is declared
/// and only getDistribution() is overridden, so isDifferentiable() keeps the
/// base-class answer (false).
/// NOTE(review): by its name it presumably puts all probability on the action
/// with the highest value -- confirm in the .cpp file.
00130 class CGreedyDistribution : public CActionDistribution
00131 {
00132 public:
/// Implementation of CActionDistribution::getDistribution().
00133 virtual void getDistribution(CStateCollection *state, CActionSet *availableActions, double *values);
00134 };
00135
00137
/// Concrete action distribution constructed from a single parameter `epsilon`.
/// Only getDistribution() is overridden, so isDifferentiable() keeps the
/// base-class answer (false).
/// NOTE(review): by its name this is the usual epsilon-greedy scheme (greedy
/// action most of the time, a random one with probability epsilon); the exact
/// probabilities are computed in the .cpp file -- confirm there.
00142 class CEpsilonGreedyDistribution : public CActionDistribution
00143 {
// The protected section is empty; no data member is declared in this class.
00144 protected:
00145 public:
00146
00147
/// @param epsilon  parameter of the distribution
// NOTE(review): single-argument constructor without `explicit`.
00148 CEpsilonGreedyDistribution(double epsilon);
/// Implementation of CActionDistribution::getDistribution().
00149 virtual void getDistribution(CStateCollection *state, CActionSet *availableActions, double *values);
00150 };
00151
00152
00154
/// Abstract base class for policies that are built from an action set and a
/// CActionDistribution. Subclasses must implement getActionValues(), which is
/// pure virtual; the remaining virtual functions are defined outside this
/// header or have trivial inline defaults.
/// NOTE(review): presumably the policy turns the values from getActionValues()
/// into probabilities via the distribution and samples an action from them --
/// the implementation is not visible here, confirm in the .cpp file.
00162 class CStochasticPolicy: public CAgentStatisticController
00163 {
00164 protected:
/// Array of doubles, presumably a work buffer with one entry per action --
/// allocation and ownership are not visible here.
00166 double *actionValues;
/// Distribution used by the policy (presumably the constructor argument).
00167 CActionDistribution *distribution;
00168
/// newmat column vector, presumably scratch space for
/// CActionDistribution::getGradientFactors() -- TODO confirm.
00169 ColumnVector *gradientFactors;
00170
/// Feature list, presumably scratch space for getActionGradient() -- TODO confirm.
00171 CFeatureList *actionGradientFeatures;
00172
/// Action set used by the policy (presumably the `actions` constructor argument).
00173 CActionSet *availableActions;
00174
00176
/// Hook for filling in statistics about a chosen action. The default
/// implementation is empty and leaves the statistics object untouched.
00179 virtual void getActionStatistics(CStateCollection *, CAction *, CActionStatistics *) {};
00180
00181 public:
/// @param actions       set of actions the policy may choose from
/// @param distribution  distribution used by the policy
00183 CStochasticPolicy(CActionSet *actions, CActionDistribution *distribution);
// NOTE(review): not declared virtual here; it is virtual only if the
// destructor of CAgentStatisticController is virtual -- confirm.
00184 ~CStochasticPolicy();
00185
00187
/// Writes the action probabilities for @p state into @p actionValues
/// (presumably one entry per element of @p availableActions -- TODO confirm).
/// Note that the parameters `availableActions` and `actionValues` shadow the
/// members of the same name.
/// @param actionDataSet  optional action data set; defaults to NULL
00191 virtual void getActionProbabilities(CStateCollection *state, CActionSet *availableActions, double *actionValues, CActionDataSet *actionDataSet = NULL);
00193
/// Chooses the next action for @p state.
/// @param dataset  action data set passed along with the request
/// @param stat     statistics object; presumably filled through
///                 getActionStatistics() when given -- TODO confirm
/// @return pointer to the chosen action
00197 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *dataset, CActionStatistics *stat);
00198
/// Pure virtual: subclasses supply the per-action values for @p state in
/// @p actionValues. Parameters mirror getActionProbabilities().
00200 virtual void getActionValues(CStateCollection *state, CActionSet *availableActions, double *actionValues, CActionDataSet *actionDataSet = NULL) = 0;
00201
00202
/// Tells whether the gradient functions below are supported. The default is
/// false; subclasses may override it.
00203 virtual bool isDifferentiable() {return false;};
00204
/// Presumably stores the gradient of the probability of @p action in
/// @p gradientState -- defined outside this header, TODO confirm.
00205 virtual void getActionProbabilityGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientState);
/// Presumably the same for the logarithm of the probability ("Ln") -- defined
/// outside this header, TODO confirm.
00206 virtual void getActionProbabilityLnGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientState);
00207
00209
/// Presumably stores the gradient of the value of @p action in
/// @p gradientState; overridden by the subclasses in this header -- TODO confirm.
00212 virtual void getActionGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientState);
00213 };
00214
00216
/// Stochastic policy that additionally holds a Q-function. It implements the
/// pure virtual getActionValues() of CStochasticPolicy and overrides
/// getActionStatistics(), getActionGradient() and isDifferentiable().
/// NOTE(review): presumably the action values are the Q-values of the actions
/// in the given state -- the implementation is not visible here, confirm in
/// the .cpp file.
00222 class CQStochasticPolicy : public CStochasticPolicy
00223 {
00224 protected:
/// Q-function of the policy; returned by getQFunction().
00226 CAbstractQFunction *qfunction;
/// Overrides the empty default of CStochasticPolicy; defined outside this header.
00228 virtual void getActionStatistics(CStateCollection *state, CAction *action, CActionStatistics *stat);
00229
00230 public:
/// @param actions       set of actions the policy may choose from
/// @param distribution  distribution used by the policy
/// @param qfunction     Q-function used by the policy
00231 CQStochasticPolicy(CActionSet *actions, CActionDistribution *distribution, CAbstractQFunction *qfunction);
// NOTE(review): not declared virtual here -- see the base classes; confirm that
// a base-class destructor is virtual.
00232 ~CQStochasticPolicy();
00233
/// Implementation of CStochasticPolicy::getActionValues().
/// @param actionDataSet  optional action data set; defaults to NULL
00234 virtual void getActionValues(CStateCollection *state, CActionSet *availableActions, double *actionValues, CActionDataSet *actionDataSet = NULL);
00235
/// Overrides CStochasticPolicy::getActionGradient(); defined outside this header.
00236 virtual void getActionGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientState);
/// Overrides the base default (false); the actual answer is computed in the
/// .cpp file (presumably depending on the distribution and the Q-function --
/// TODO confirm).
00237 virtual bool isDifferentiable();
00238
/// @return the stored Q-function pointer
00239 virtual CAbstractQFunction *getQFunction() {return qfunction;};
00240 };
00241
// Forward declaration needed by CVMStochasticPolicy below, which stores a
// pointer of this type. The same declaration already appears in the list of
// forward declarations at the top of this header, so this repetition is
// redundant but harmless.
00242 class CQFunctionFromTransitionFunction;
00243
00244
00246
/// Q-stochastic policy that is constructed from a V-function, a transition
/// function (model) and a reward function instead of a Q-function.
/// NOTE(review): the constructor takes no Q-function although the base class
/// CQStochasticPolicy requires one; presumably `qFunctionFromTransitionFunction`
/// is built from the V-function, model and reward and handed to the base --
/// confirm in the .cpp file.
/// NOTE(review): CRewardFunction, CStateModifier and std::list are neither
/// declared nor included visibly in this header; presumably they come in
/// through the included headers -- confirm.
00257 class CVMStochasticPolicy : public CQStochasticPolicy
00258 {
00259 protected:
/// State collection, presumably a buffer for the predicted successor state --
/// TODO confirm.
00260 CStateCollectionImpl *nextState;
/// State collection, presumably a buffer for an intermediate model state --
/// TODO confirm.
00261 CStateCollectionImpl *intermediateState;
00262
/// V-function of the policy (presumably the constructor argument).
00263 CAbstractVFunction *vFunction;
/// Q-function derived from a transition function; creation and ownership are
/// not visible here.
00264 CQFunctionFromTransitionFunction *qFunctionFromTransitionFunction;
/// Transition function of the policy (presumably the constructor argument).
00265 CTransitionFunction *model;
/// Reward function of the policy (presumably the constructor argument).
00266 CRewardFunction *reward;
00267 public:
00268
/// @param actions       set of actions the policy may choose from
/// @param distribution  distribution used by the policy
/// @param vFunction     V-function used by the policy
/// @param model         transition function used by the policy
/// @param reward        reward function used by the policy
/// @param modifiers     list of state modifiers; presumably attached to the
///                      internal state collections -- TODO confirm
00269 CVMStochasticPolicy(CActionSet *actions, CActionDistribution *distribution, CAbstractVFunction *vFunction, CTransitionFunction *model, CRewardFunction *reward, std::list<CStateModifier *> *modifiers);
// NOTE(review): not declared virtual here -- see the base classes; which of the
// pointer members are released here is not visible in this header.
00270 ~CVMStochasticPolicy();
00271
/// Overrides CQStochasticPolicy::getActionGradient(); defined outside this header.
00272 virtual void getActionGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientState);
00273
/// Overrides CQStochasticPolicy::isDifferentiable(); defined outside this header.
00274 virtual bool isDifferentiable();
00275 };
00276
00277
00278 #endif
00279