00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef __CTransitionFunction_H
00033 #define __CTransitionFunction_H
00034
00035 #include "cparameters.h"
00036 #include "cbaseobjects.h"
00037 #include "crewardfunction.h"
00038 #include "cenvironmentmodel.h"
00039 #include "cqfunction.h"
00040
00041
// Model-type flags. Presumably these are the values handled by
// CTransitionFunction::addType()/isType(); each value is a distinct power of
// two, so several flags fit into a single int -- TODO confirm in the .cpp.
#define DM_CONTINUOUSMODEL      1
#define DM_DERIVATIONUMODEL     2
#define DM_EXTENDEDACTIONMODEL  4
00045
// Reset-type codes, listed in ascending order. Presumably these are the
// values accepted by CTransitionFunction::setResetType() -- TODO confirm.
#define DM_RESET_TYPE_ZERO        0
#define DM_RESET_TYPE_RANDOM      1
#define DM_RESET_TYPE_ALL_RANDOM  2
00049
// Forward declarations of types that this header only uses through pointers
// (alphabetical order).
class CAbstractFeatureStochasticModel;
class CAbstractVFunction;
class CActionDataSet;
class CContinuousAction;
class CContinuousActionProperties;
class CRegion;
class CStateCollectionImpl;
class CStateCollectionList;
class CStateList;
class CVFunctionInputDerivationCalculator;
00060
00061 class CTransitionFunction : public CStateObject, public CActionObject, virtual public CParameterObject
00062 {
00063 protected:
00064 int type;
00065
00066 int resetType;
00067 public:
00068 CTransitionFunction(CStateProperties *properties, CActionSet *actions);
00069
00070 int getType();
00071 void addType(int Type);
00072 bool isType(int type);
00073
00074 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL) = 0;
00075
00076 virtual void getDerivationU(CState *oldstate, Matrix *derivation);
00077
00078 virtual bool isResetState(CState *) {return false;};
00079 virtual bool isFailedState(CState *) {return false;};
00080
00081 virtual void getResetState(CState *resetState);
00082
00083 virtual void setResetType(int resetType);
00084 };
00085
00086 class CExtendedActionTransitionFunction : public CTransitionFunction, public CRewardFunction
00087 {
00088 protected:
00089 CTransitionFunction *dynModel;
00090
00091 CStateCollectionImpl *intermediateState;
00092 CStateCollectionImpl *nextState;
00093
00094 CActionDataSet *actionDataSet;
00095
00096 CRewardFunction *rewardFunction;
00097 double lastReward;
00098 public:
00099 CExtendedActionTransitionFunction(CActionSet *actions, CTransitionFunction *model, std::list<CStateModifier *> *modifiers, CRewardFunction *rewardFunction = NULL) ;
00100 ~CExtendedActionTransitionFunction();
00101
00102 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL);
00103 virtual double transitionFunctionAndReward(CState *oldState, CAction *action, CState *newState, CActionData *data, CRewardFunction *reward, double gamma);
00104
00105 virtual void getDerivationU(CState *oldstate, Matrix *derivation);
00106
00107 virtual bool isResetState(CState *state);
00108 virtual bool isFailedState(CState *state);
00109
00110 virtual void getResetState(CState *resetState);
00111
00112 virtual void setResetType(int resetType);
00113
00114 virtual double getReward(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00115 };
00116
00117
00118 class CComposedTransitionFunction : public CTransitionFunction
00119 {
00120 protected:
00121 std::list<CTransitionFunction *> *TransitionFunction;
00122 public:
00123
00124 CComposedTransitionFunction(CStateProperties *properties);
00125 ~CComposedTransitionFunction();
00126
00127 void addTransitionFunction(CTransitionFunction *model);
00128
00129 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL);
00130 };
00131
00132 class CContinuousTimeTransitionFunction : public CTransitionFunction
00133 {
00134 protected:
00135 double dt;
00136 int simulationSteps;
00137
00138 ColumnVector *derivation;
00139
00140 virtual void doSimulationStep(CState *oldState, double timeStep, CAction *action, CActionData *data);
00141
00142 public:
00143 CContinuousTimeTransitionFunction(CStateProperties *properties, CActionSet *actions, double dt);
00144 virtual ~CContinuousTimeTransitionFunction();
00145
00146 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL);
00147
00148 double getTimeIntervall();
00149 void setTimeIntervall(double dt);
00150
00151 void setSimulationSteps(int steps);
00152 int getSimulationSteps();
00153
00154 virtual void getDerivationX(CState *oldstate, CAction *action, ColumnVector *derivation, CActionData *data = NULL) = 0;
00155 };
00156
// Forward declarations for the continuous-action classes that follow; both
// types are only used through pointers below.
class CContinuousActionData;
class CContinuousAction;
00159
00160 class CContinuousTimeAndActionTransitionFunction : public CContinuousTimeTransitionFunction
00161 {
00162 protected:
00163 CContinuousActionProperties *actionProp;
00164 CContinuousAction *contAction;
00165 public:
00166 CContinuousTimeAndActionTransitionFunction(CStateProperties *properties, CContinuousAction *action, double dt);
00167 virtual ~CContinuousTimeAndActionTransitionFunction();
00168
00169 virtual void getDerivationX(CState *oldState, CAction *action, ColumnVector *derivationX, CActionData *data = NULL);
00170 virtual void getCADerivationX(CState *oldState, CContinuousActionData *action, ColumnVector *derivationX) = 0;
00171
00172
00173 CContinuousAction *getContinuousAction();
00174 };
00175
00176
00177 class CLinearActionContinuousTimeTransitionFunction : public CContinuousTimeAndActionTransitionFunction
00178 {
00179 protected:
00180
00181 ColumnVector *A;
00182 Matrix *B;
00183
00184 public:
00185 CLinearActionContinuousTimeTransitionFunction(CStateProperties *properties, CContinuousAction *action, double dt);
00186 ~CLinearActionContinuousTimeTransitionFunction();
00187
00188 virtual void getCADerivationX(CState *oldState, CContinuousActionData *action, ColumnVector *derivationX);
00189
00190 virtual void getDerivationU(CState *oldstate, Matrix *derivation);
00191 virtual Matrix *getB(CState *state) = 0;
00192 virtual ColumnVector *getA(CState *state) = 0;
00193
00194 };
00195
00196 class CDynamicLinearContinuousTimeModel : public CLinearActionContinuousTimeTransitionFunction
00197 {
00198 protected:
00199 Matrix *B;
00200 Matrix *AMatrix;
00201
00202 public:
00203 CDynamicLinearContinuousTimeModel(CStateProperties *properties, CContinuousAction *action, double dt, Matrix *A, Matrix *B);
00204 ~CDynamicLinearContinuousTimeModel();
00205
00206 virtual Matrix *getB(CState *state);
00207 virtual ColumnVector *getA(CState *state);
00208 };
00209
00210
00211 class CTransitionFunctionEnvironment : public CEnvironmentModel
00212 {
00213 protected:
00214 CTransitionFunction *TransitionFunction;
00215 CState *modelState;
00216 CState *nextState;
00217
00218 CStateList *startStates;
00219 int nEpisode;
00220 bool createdStartStates;
00221
00222 CRegion *failedRegion;
00223 CRegion *sampleRegion;
00224 CRegion *targetRegion;
00225 public:
00226 CTransitionFunctionEnvironment(CTransitionFunction *model);
00227 virtual ~CTransitionFunctionEnvironment();
00228
00229 virtual void doNextState(CPrimitiveAction *action);
00230 virtual void doResetModel();
00231
00232 virtual void getState(CState *state);
00233
00234 virtual void setState(CState *state);
00235
00236 virtual void setStartStates(CStateList *startStates);
00237 virtual void setStartStates(char *filename);
00238
00239 CTransitionFunction *getTransitionFunction() {return TransitionFunction;};
00240
00241 void setSampleRegion(CRegion *sampleRegion);
00242 void setFailedRegion(CRegion *failedRegion);
00243 void setTargetRegion(CRegion *sampleRegion);
00244 };
00245
00246 class CTransitionFunctionFromStochasticModel : public CTransitionFunction
00247 {
00248 protected:
00249 CAbstractFeatureStochasticModel *stochasticModel;
00250
00251 std::list<int> *startStates;
00252 std::list<double> *startProbabilities;
00253 std::map<int, double> *endStates;
00254 public:
00255 CTransitionFunctionFromStochasticModel(CStateProperties *properties, CAbstractFeatureStochasticModel *model);
00256 ~CTransitionFunctionFromStochasticModel();
00257
00258 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL);
00259
00260 void addStartState(int state, double probability);
00261 void addEndState(int state, double probability);
00262
00263 virtual bool isResetState(CState *state);
00264 virtual void getResetState(CState *state);
00265 };
00266
00267 class CQFunctionFromTransitionFunction : public CAbstractQFunction, public CStateModifiersObject
00268 {
00269 protected:
00270
00272 CAbstractVFunction *vfunction;
00274 CTransitionFunction *model;
00276 CRewardFunction *rewardfunction;
00277
00279 CStateCollectionImpl *intermediateState;
00280 CStateCollectionImpl *nextState;
00281
00282
00283 CStateCollectionList *stateCollectionList;
00284 CActionDataSet *actionDataSet;
00285
00286 public:
00288
00289 CQFunctionFromTransitionFunction(CActionSet *actions, CAbstractVFunction *vfunction, CTransitionFunction *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers);
00290
00291 virtual ~CQFunctionFromTransitionFunction();
00292
00294
00295
00297 virtual void setValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00298
00300 virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00301
00302
00304
00305 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00306
00307 double getValueDepthSearch(CStateCollectionList *state, CAction *action, CActionData *data, int depth);
00308
00309 virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00310
00311 virtual void addStateModifier(CStateModifier *modifier);
00312 };
00313
00314 class CContinuousTimeQFunctionFromTransitionFunction : public CAbstractQFunction, public CStateModifiersObject
00315 {
00316 protected:
00317
00319 CVFunctionInputDerivationCalculator *vfunction;
00321 CContinuousTimeTransitionFunction *model;
00323 CRewardFunction *rewardfunction;
00324
00325 CStateCollectionImpl *nextState;
00326
00327 CState *derivationXModel;
00328 CState *derivationXVFunction;
00329
00330 virtual double getValueVDerivation(CStateCollection *state, CAction *action, CActionData *data, ColumnVector *derivationXVFunction);
00331 public:
00333
00334 CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CContinuousTimeTransitionFunction *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers);
00335
00336 CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CContinuousTimeTransitionFunction *model, CRewardFunction *rewardfunction);
00337
00338 virtual ~CContinuousTimeQFunctionFromTransitionFunction();
00339
00340 virtual void getActionValues(CStateCollection *state, CActionSet *actions, double *actionValues, CActionDataSet *actionDataSet);
00341
00343 virtual void setValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00344
00346 virtual void updateValue(CStateCollection *, CAction *, double , CActionData * = NULL) {};
00347
00348
00350
00351 virtual double getValue(CStateCollection *state, CAction *action, CActionData *data = NULL);
00352
00353
00354 virtual CAbstractQETraces *getStandardETraces() {return NULL;};
00355
00356 virtual void addStateModifier(CStateModifier *modifier);
00357
00358 };
00359
00360
00361
00362 #endif
00363