00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CRIABSTRACTQFUNCTION_H
00033 #define CRIABSTRACTQFUNCTION_H
00034
#include <stdio.h>

#include <string>

#include "clearndataobject.h"
#include "cbaseobjects.h"
#include "cmyexception.h"
#include "cgradientfunction.h"
00041
00042 class CAbstractFeatureStochasticModel;
00043 class CAbstractQETraces;
00044 class CGradientQETraces;
00045
00046 class CAbstractVFunction;
00047 class CFeatureVFunction;
00048 class CFeatureRewardFunction;
00049 class CActionStatistics;
00050 class CFeature;
00051
/// Type tags for Q-functions (see CAbstractQFunction::getType(), isType()
/// and addType()).
/// NOTE(review): the two values are distinct powers of two, so they are
/// presumably meant to be combined as bit flags -- confirm against the
/// implementation of addType()/isType().
#define GRADIENTQFUNCTION         1
#define CONTINUOUSACTIONQFUNCTION 2
00054
00055
00057
00070 class CAbstractQFunction : public CActionObject, virtual public CLearnDataObject
00071 {
00072 protected:
00073 int type;
00074 public:
00075 bool mayDiverge;
00076
00077 int getType();
00078 bool isType(int type);
00079 void addType(int Type);
00080
00081
00083 CAbstractQFunction(CActionSet *actions);
00084 virtual ~CAbstractQFunction();
00085
00086 virtual void saveData(FILE *file);
00087 virtual void loadData(FILE *file);
00088 virtual void printValues (){};
00089
00090 virtual void resetData() {};
00091
00093
00094 void getActionValues(CStateCollection *state, CActionSet *actions, double *actionValues, CActionDataSet *data = NULL);
00095
00097
00099 virtual CAction* getMax(CStateCollection *state, CActionSet *availableActions, CActionDataSet *data = NULL);
00101
00102 virtual double getMaxValue(CStateCollection *state, CActionSet *availableActions);
00104 virtual void getStatistics(CStateCollection *state, CAction *action, CActionSet *actions, CActionStatistics* statistics);
00105
00107 virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00109 virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL);
00111 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL) = 0;
00112
00113 virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00114
00115 protected:
00116 };
00117
00118
00119 class CQFunctionSum : public CAbstractQFunction
00120 {
00121 protected:
00122 std::map<CAbstractQFunction *, double> *qFunctions;
00123 public:
00124 CQFunctionSum(CActionSet *actions);
00125 virtual ~CQFunctionSum();
00126
00127
00129 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00130
00131 virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00132
00133 double getQFunctionFactor(CAbstractQFunction *qFunction);
00134 void setQFunctionFactor(CAbstractQFunction *qFunction, double factor);
00135
00136 void addQFunction(CAbstractQFunction *qFunction, double factor);
00137 void removeQFunction(CAbstractQFunction *qFunction);
00138
00139
00140 void normFactors(double factor);
00141
00142 };
00143
00145
00148 class CDivergentQFunctionException : public CMyException
00149 {
00150 protected:
00151 virtual string getInnerErrorMsg();
00152 public:
00153 string qFunctionName;
00154 CAbstractQFunction *qFunction;
00155 CState *state;
00156 double value;
00157
00158 CDivergentQFunctionException(string qFunctionName, CAbstractQFunction *qFunction, CState *state, double value);
00159 virtual ~CDivergentQFunctionException(){};
00160 };
00161
00162 class CGradientQFunction : public CAbstractQFunction, virtual public CGradientUpdateFunction
00163 {
00164 protected:
00165 CFeatureList *localGradientQFunctionFeatures;
00166
00167 public:
00168 CGradientQFunction(CActionSet *actions);
00169 virtual ~CGradientQFunction();
00170
00171 virtual int getWeightsOffset(CAction *) {return 0;};
00172
00173 virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient) = 0;
00174
00176 virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00177
00178 virtual void resetData() {CAbstractQFunction::resetData();};
00179 virtual void loadData(FILE *stream) {CGradientUpdateFunction::loadData(stream);};
00180 virtual void saveData(FILE *stream) {CGradientUpdateFunction::saveData(stream);};
00181
00182 virtual CAbstractQETraces *getStandardETraces();
00183
00184 virtual void copy(CLearnDataObject *qFunction) {CGradientUpdateFunction::copy(qFunction);};
00185 };
00186
00187
00188
00189
00190
00191
00192
00193
00194
00196
00197
00198
00199
00200
00201
00202
00203
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00216
00227 class CQFunction : public CGradientQFunction
00228 {
00229 protected:
00231
00233 std::map<CAction *, CAbstractVFunction *> *vFunctions;
00234
00235 virtual int getWeightsOffset(CAction *action);
00236
00237 virtual void updateWeights(CFeatureList *features);
00238
00239 public:
00241
00244 CQFunction(CActionSet *actions);
00245 virtual ~CQFunction();
00246
00248
00250 virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00252
00254 virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL);
00256
00258 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00259
00261
00264 virtual void updateValue(CState *state, CAction *action, double td, CActionData *data = NULL);
00266
00269 virtual void setValue(CState *state, CAction *action, double qValue, CActionData *data = NULL);
00271
00274 virtual double getValue(CState *state, CAction *action, CActionData *data = NULL);
00275
00277
00280 virtual void saveData(FILE *file);
00282
00284 virtual void loadData(FILE *file);
00286 virtual void printValues();
00287
00289 CAbstractVFunction *getVFunction(CAction *action);
00291 CAbstractVFunction *getVFunction(int index);
00293
00296 void setVFunction(CAction *action, CAbstractVFunction *vfunction, bool bDeleteOld = true);
00298
00301 void setVFunction(int index, CAbstractVFunction *vfunction, bool bDeleteOld = true);
00303 int getNumVFunctions();
00304
00305 virtual CAbstractQETraces *getStandardETraces();
00306
00307
00308 virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient);
00309
00310 virtual int getNumWeights();
00311
00312 virtual void getWeights(double *weights);
00313 virtual void setWeights(double *weights);
00314
00315 virtual void resetData();
00316 virtual void copy(CLearnDataObject *qFunction);
00317 };
00318
00320
00328 class CQFunctionFromStochasticModel : public CAbstractQFunction, public CStateObject
00329 {
00330 protected:
00331
00333 CFeatureVFunction *vfunction;
00335 CAbstractFeatureStochasticModel *model;
00337 CStateProperties *discretizer;
00339 CFeatureRewardFunction *rewardfunction;
00340
00342 CState *discState;
00343
00344 public:
00346 CQFunctionFromStochasticModel(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardfunction);
00347
00348 virtual ~CQFunctionFromStochasticModel();
00349
00350
00351
00352
00354 virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00356 virtual void setValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00357
00359
00360 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00361
00363
00368 virtual double getValue(CState *featState, CAction *action, CActionData *data = NULL);
00370
00373 virtual double getValue(int feature, CAction *action, CActionData *data = NULL);
00374
00375 virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00376 };
00377
00378
00379
00381
00394 class CFeatureQFunction : public CQFunction
00395 {
00396 protected:
00398 CStateProperties *discretizer;
00400 unsigned int features;
00401
00402 std::list<CFeatureVFunction *> *featureVFunctions;
00403
00405 virtual void init();
00406
00408
00412 void initVFunctions(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardFunction, double gamma);
00413
00414 public:
00416 CFeatureQFunction(CActionSet *actions, CStateProperties *discretizer);
00418
00422 CFeatureQFunction(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardFunction,double gamma);
00423
00424 virtual ~CFeatureQFunction();
00425
00427
00430 void updateValue(CFeature *state, CAction *action, double td, CActionData *data = NULL);
00432
00435 void setValue(int state, CAction *action, double qValue, CActionData *data = NULL);
00437
00440 double getValue(int feature, CAction *action, CActionData *data = NULL);
00441
00442 void setFeatureCalculator(CStateModifier *discretizer);
00443 CStateProperties *getFeatureCalculator();
00444
00445
00446 int getNumFeatures();
00447
00449
00450 void saveFeatureActionValueTable(FILE *stream);
00452
00453 void saveFeatureActionTable(FILE *stream);
00454 };
00455
00456 class CComposedQFunction : public CGradientQFunction
00457 {
00458 protected:
00459 std::list<CAbstractQFunction *> *qFunctions;
00460
00461 virtual int getWeightsOffset(CAction *action);
00462 virtual void updateWeights(CFeatureList *features);
00463
00464 public:
00465 CComposedQFunction();
00466 virtual ~CComposedQFunction();
00467
00468 virtual void saveData(FILE *file);
00469 virtual void loadData(FILE *file);
00470 virtual void printValues();
00471
00472 virtual void getStatistics(CStateCollection *state, CAction *action, CActionSet *actions, CActionStatistics* statistics);
00473
00475 virtual void updateValue(CStateCollection *state, CAction *action, double td, CActionData *data = NULL);
00477 virtual void setValue(CStateCollection *state, CAction *action, double qValue, CActionData *data = NULL);
00479 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00480
00481 void addQFunction(CAbstractQFunction *qFunction);
00482
00483 std::list<CAbstractQFunction *> *getQFunctions();
00484 int getNumQFunctions();
00485
00486 virtual CAbstractQETraces *getStandardETraces();
00487
00488
00489
00490 virtual void getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradient);
00491
00492
00493 virtual int getNumWeights();
00494 virtual void getWeights(double *weights);
00495 virtual void setWeights(double *weights);
00496
00497 virtual void resetData();
00498 };
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516 #endif
00517
00518
00519