00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef C_ABSTRACTVFUNCTION_H
00033 #define C_ABSTRACTVFUNCTION_H
00034
#include <stdio.h>

#include <list>
#include <map>
#include <string>

#include "clearndataobject.h"
#include "cmyexception.h"
#include "cbaseobjects.h"
#include "cgradientfunction.h"
#include "cfeaturefunction.h"
00042
00043 class CAbstractVETraces;
00044 class CFeatureQFunction;
00045 class CStochasticPolicy;
00046 class CRewardFunction;
00047 class CStateCollectionImpl;
00048 class CFeatureCalculator;
00049 class CAbstractStateDiscretizer;
00050 class CStateReward;
00051
/// Magnitude associated with a diverging V-function value.
/// NOTE(review): its use is not visible in this header; presumably the
/// out-of-line code compares a computed value against this bound (together
/// with CAbstractVFunction::mayDiverge) and reports a
/// CDivergentVFunctionException when it is exceeded — confirm in the .cpp.
#define DIVERGENTVFUNCTIONVALUE 1000000

/// Type flag, presumably passed to CAbstractVFunction::addType()/isType():
/// marks a continuous V-function.
/// NOTE(review): 1 and 2 are distinct powers of two, which suggests they are
/// combined as bit flags — confirm against addType()/isType().
#define CONTINUOUSVFUNCTION 1
/// Type flag, presumably passed to CAbstractVFunction::addType()/isType():
/// marks a gradient-based V-function (see CGradientVFunction).
#define GRADIENTVFUNCTION 2
00056
00057
00079 class CAbstractVFunction : public CStateObject, virtual public CLearnDataObject {
00080 protected:
00081 int type;
00082
00083 void addType(int newType);
00084
00085 public:
00086 bool mayDiverge;
00087
00089 CAbstractVFunction(CStateProperties *properties);
00090
00091 virtual ~CAbstractVFunction();
00092
00093 virtual void resetData() {};
00094
00096 virtual void updateValue(CStateCollection *state, double td);
00098 virtual void setValue(CStateCollection *state, double qValue);
00100 virtual double getValue(CStateCollection *state);
00101
00103 virtual void updateValue(CState *state, double td);
00105 virtual void setValue(CState *, double ) {};
00107 virtual double getValue(CState *state) = 0;
00108
00110 virtual void saveData(FILE *file);
00112 virtual void loadData(FILE *file);
00114 virtual void printValues (){};
00115
00116 int getType();
00117 bool isType(int isT);
00118
00120
00124 virtual CAbstractVETraces *getStandardETraces();
00125 };
00126
00128 class CZeroVFunction : public CAbstractVFunction
00129 {
00130 protected:
00131 public:
00132 CZeroVFunction();
00133
00134 virtual double getValue(CState *state);
00135 };
00136
00137 class CVFunctionSum : public CAbstractVFunction
00138 {
00139 protected:
00140 std::map<CAbstractVFunction *, double> *vFunctions;
00141 public:
00142 CVFunctionSum();
00143 virtual ~CVFunctionSum();
00144
00145
00147 virtual double getValue(CStateCollection *state);
00148 virtual double getValue(CState *state) {return getValue((CStateCollection *) state);};
00149
00150
00151 virtual CAbstractVETraces *getStandardETraces() {return NULL;};
00152
00153 double getVFunctionFactor(CAbstractVFunction *vFunction);
00154 void setVFunctionFactor(CAbstractVFunction *vFunction, double factor);
00155
00156 void addVFunction(CAbstractVFunction *vFunction, double factor = 1.0);
00157 void removeVFunction(CAbstractVFunction *vFunction);
00158
00159 void normFactors(double factor);
00160 };
00161
00163
00166 class CDivergentVFunctionException : public CMyException
00167 {
00168 protected:
00169 virtual string getInnerErrorMsg();
00170 public:
00171 string vFunctionName;
00172 CAbstractVFunction *vFunction;
00173 CState *state;
00174 double value;
00175
00176 CDivergentVFunctionException(string vFunctionName, CAbstractVFunction *vFunction, CState *state, double value);
00177 virtual ~CDivergentVFunctionException(){};
00178 };
00179
00181
00194 class CGradientVFunction : public CAbstractVFunction, virtual public CGradientUpdateFunction
00195 {
00196 protected:
00197
00198 public:
00200 CGradientVFunction(CStateProperties *properties);
00201 virtual ~CGradientVFunction();
00202
00204 virtual void updateValue(CStateCollection *state, double td);
00206 virtual void updateValue(CState *state, double td);
00207
00208 virtual void getGradient(CStateCollection *state, CFeatureList *gradientFeatures) = 0;
00209
00210 virtual void resetData() = 0;
00211 virtual void loadData(FILE *stream) {CGradientUpdateFunction::loadData(stream);};
00212 virtual void saveData(FILE *stream) {CGradientUpdateFunction::saveData(stream);};
00213
00214 virtual CAbstractVETraces *getStandardETraces();
00215
00216
00217 virtual void copy(CLearnDataObject *vFunction) {CGradientUpdateFunction::copy(vFunction);};
00218 };
00219
00220
00221
00222
00223
00224
00225
00226
00227
00229
00230
00231
00232
00233
00234
00235
00236
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00249
00254 class CVFunctionInputDerivationCalculator : virtual public CParameterObject
00255 {
00256 protected:
00257 CStateProperties *modelState;
00258 public:
00259 CVFunctionInputDerivationCalculator(CStateProperties *modelState);
00260
00261 virtual void getInputDerivation( CStateCollection *state, ColumnVector *targetVector) = 0;
00262 unsigned int getNumInputs();
00263 };
00264
00265
00267
00274 class CVFunctionNumericInputDerivationCalculator : public CVFunctionInputDerivationCalculator
00275 {
00276 protected:
00277 CAbstractVFunction *vFunction;
00278 CStateCollectionImpl *stateBuffer;
00279 public:
00280 CVFunctionNumericInputDerivationCalculator(CStateProperties *modelState, CAbstractVFunction *vFunction, double stepSize, std::list<CStateModifier *> *modifiers);
00281 virtual ~CVFunctionNumericInputDerivationCalculator();
00282
00283 virtual void getInputDerivation( CStateCollection *state, ColumnVector *targetVector);
00284 };
00285
00286
00287
00309 class CFeatureVFunction : public CGradientVFunction, public CFeatureFunction
00310 {
00311 protected:
00312
00313
00314 public:
00315 CFeatureVFunction(int numFeatures);
00316
00318
00322 CFeatureVFunction(CStateProperties *featureFact);
00323
00328 CFeatureVFunction(CFeatureQFunction *qfunction, CStochasticPolicy *policy);
00329
00330 ~CFeatureVFunction();
00334 virtual void setVFunctionFromQFunction(CFeatureQFunction *qfunction, CStochasticPolicy *policy);
00335
00336 virtual void updateWeights(CFeatureList *gradientFeatures);
00337
00338
00340
00343 virtual void updateValue(CState *state, double td);
00345
00348 virtual void setValue(CState *state, double qValue);
00350
00353 virtual double getValue(CState *state);
00354
00356 virtual void saveData(FILE *file);
00358 virtual void loadData(FILE *file);
00359 virtual void printValues();
00360
00362 virtual CAbstractVETraces *getStandardETraces();
00363
00364 virtual void getGradient(CStateCollection *state, CFeatureList *gradientFeatures);
00365
00366
00367 virtual int getNumWeights();
00368
00369 virtual void resetData();
00370
00371
00372 virtual void getWeights(double *parameters);
00373 virtual void setWeights(double *parameters);
00374
00375 void setFeatureCalculator(CFeatureCalculator *featCalc);
00376
00377 };
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00398
00399
00400
00401
00402 class CVTable : public CFeatureVFunction
00403 {
00404 public:
00405 CVTable(CAbstractStateDiscretizer *state);
00406
00407 ~CVTable();
00408
00409 void setDiscretizer(CAbstractStateDiscretizer *discretizer);
00410 CAbstractStateDiscretizer *getDiscretizer();
00411
00412 int getNumStates();
00413 };
00414
00415 class CRewardAsVFunction : public CAbstractVFunction
00416 {
00417 protected:
00418 CStateReward *reward;
00419 public:
00420 CRewardAsVFunction(CStateReward *reward);
00421 virtual ~CRewardAsVFunction() {};
00422
00423 virtual double getValue(CState *state);
00424 };
00425
00426
00427 #endif
00428