00001 #ifndef C_SAMPLINGBASED_MODEL
00002 #define C_SAMPLINGBASED_MODEL
00003
00004
00005 #include "cagentlistener.h"
00006 #include "cagentcontroller.h"
00007 #include "clearndataobject.h"
00008 #include "cinputdata.h"
00009 #include "cbaseobjects.h"
00010 #include "cparameters.h"
00011 #include "cevaluator.h"
00012
00013 class CEpisodeHistory;
00014 class CAction;
00015 class CTransitionFunction;
00016 class CRewardLogger;
00017 class CActionSet;
00018 class CStateModifier;
00019 class CKDTree;
00020 class CKNearestNeighbors;
00021 class CRangeSearch;
00022 class CKDRectangle;
00023
00024 class CContinuousStateList : public CDataSet, public CSemiMDPListener, public CLearnDataObject, public CStateObject
00025 {
00026 protected:
00027 CEpisodeHistory *initLogger;
00028
00029 CKNearestNeighbors *nearestNeighbor;
00030 CRangeSearch *rangeSearch;
00031
00032 CKDTree *kdTree;
00033
00034
00035 public:
00036
00037
00038 CContinuousStateList(CStateProperties *properties);
00039 virtual ~CContinuousStateList();
00040
00041 virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00042
00043 virtual int addState(CState *state, double minDist = -1.0);
00044
00045 virtual void createStateList(CEpisodeHistory *history, bool useInitLogger = false);
00046
00047 virtual void resetData();
00048
00049 virtual void loadData(FILE *);
00050 virtual void saveData(FILE *);
00051
00052 void initNearestNeighborSearch();
00053 void disableNearestNeighborSearch();
00054
00055 virtual bool isMember(ColumnVector *point);
00056 virtual void getNearestNeighbor(ColumnVector *point, int &index, double &distance);
00057
00058 virtual void getSamplesInRange(CKDRectangle *rectangle, DataSubset *subset);
00059
00060 virtual CKDTree *getKDTree() {return kdTree;};
00061
00062 };
00063
00064
00065
00066 class CSampleTransition
00067 {
00068 public:
00069 CSampleTransition(CState *state, CActionSet *availableActions, double reward, CActionData *actionData = NULL);
00070
00071 virtual ~CSampleTransition();
00072
00073 CState *state;
00074 CActionSet *availableActions;
00075 double reward;
00076 CActionData *actionData;
00077 };
00078
00079 class CSamplingBasedTransitionModel : public CActionObject, CSemiMDPRewardListener, public CLearnDataObject, public CStateObject
00080 {
00081 protected:
00082 CEpisodeHistory *initLogger;
00083 CRewardLogger *initRewardLogger;
00084
00085 typedef std::map<CAction *, CSampleTransition *> Transitions;
00086
00087 CRewardFunction *rewardFunction;
00088 CStateProperties *targetProperties;
00089
00090 std::map<int, Transitions *> *transitions;
00091
00092 CContinuousStateList *stateList;
00093
00094 void clearTransitions();
00095 void addTransition(int index, CAction *action, CStateCollection *state, double reward);
00096 public:
00097 CSamplingBasedTransitionModel(CStateProperties *properties, CStateProperties *targetProperties, CActionSet *actions, CRewardFunction *rewardFunction);
00098 virtual ~CSamplingBasedTransitionModel();
00099
00100 virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00101
00102 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);
00103
00104 virtual void resetData();
00105 virtual void loadData(FILE *);
00106 virtual void saveData(FILE *);
00107
00108 int getNumStates();
00109
00110 std::map<CAction *, CSampleTransition *> *getTransitions(int index);
00111
00112 virtual void createStateList(CEpisodeHistory *history, CRewardLogger *logger, bool useInitLogger = false);
00113
00114 CContinuousStateList *getStateList();
00115 };
00116
00117 class CGraphTransition
00118 {
00119 public:
00120 CGraphTransition(int newStateIndex, double reward, double discountFactor, CAction *action, CActionData *actionData = NULL);
00121
00122 virtual ~CGraphTransition();
00123
00124 int newStateIndex;
00125 double reward;
00126 CActionData *actionData;
00127 CAction *action;
00128
00129 double discountFactor;
00130
00131 virtual double getReward();
00132 };
00133
00134
00135 class CSamplingBasedGraph : public CActionObject, public CStateObject, public CLearnDataObject
00136 {
00137 protected:
00138
00139 typedef std::list<CGraphTransition *> Transitions;
00140
00141 std::map<int, Transitions *> *transitions;
00142
00143 CContinuousStateList *stateList;
00144
00145 void clearTransitions();
00146 virtual void addTransition(int index, int newIndex, CAction *action, CActionData *actionData, double reward, double discountFactor);
00147
00148 int numTransitions;
00149 public:
00150 CSamplingBasedGraph(CContinuousStateList *stateList, CActionSet *actions);
00151 virtual ~CSamplingBasedGraph();
00152
00153 virtual void resetData();
00154
00155 int getNumStates();
00156
00157 std::list<CGraphTransition *> *getTransitions(int index);
00158
00159 CContinuousStateList *getStateList();
00160
00161 virtual void loadData(FILE *stream);
00162 virtual void saveData(FILE *stream);
00163
00164 virtual void getConnectedNodes(int node, DataSubset *subset);
00165
00166 virtual void createTransitions();
00167
00168 virtual bool calculateTransition(int startNode, int endNode) = 0;
00169 virtual void getNeighboredNodes(int node, DataSubset *elementList) = 0;
00170
00171 virtual bool isFinalNode(int node) = 0;
00172 virtual void addFinalTransition(int node) = 0;
00173
00174 virtual void addState(CState *addState);
00175 virtual void addTransitions(int node, bool newNode = false);
00176 };
00177
00178
00179
00180 class CGraphTarget
00181 {
00182 protected:
00183 CGraphTarget *nextTarget;
00184 public:
00185 CGraphTarget(CGraphTarget *nextTarget);
00186 virtual ~CGraphTarget();
00187
00188 virtual bool isFinishedCanditate(ColumnVector *node) = 0;
00189 virtual bool isFinished(ColumnVector *oldNode, ColumnVector *newNode, double &reward) = 0;
00190
00191 virtual bool isTargetForState(CStateCollection *state) = 0;
00192
00193 CGraphTarget *getNextTarget() {return nextTarget;};
00194 };
00195
00196 class CGraphTransitionAdaptiveTarget : public CGraphTransition
00197 {
00198 protected:
00199 std::map<CGraphTarget *, double> *targetReward;
00200 std::map<CGraphTarget *, bool> *targetReached;
00201
00202 CGraphTarget **currentTarget;
00203 public:
00204 CGraphTransitionAdaptiveTarget(int newStateIndex, double reward, double discountFactor, CAction *action, CActionData *actionData, CGraphTarget **currentTarget);
00205
00206 virtual ~CGraphTransitionAdaptiveTarget();
00207
00208 virtual double getReward();
00209
00210 virtual double getReward(CGraphTarget *target);
00211 virtual bool isFinished(CGraphTarget *target);
00212 virtual void addTarget(CGraphTarget *target, double reward, bool isFinished);
00213 };
00214
00215 class CAdaptiveTargetGraph : public CSamplingBasedGraph
00216 {
00217 protected:
00218 std::list<CGraphTarget *> *targetList;
00219 CGraphTarget *currentTarget;
00220
00221 virtual void addTransition(int index, int newIndex, CAction *action, CActionData *actionData, double reward, double discountFactor);
00222
00223 virtual void addTargetForNode(CGraphTarget *target, int node);
00224 public:
00225 CAdaptiveTargetGraph(CContinuousStateList *stateList, CActionSet *actions);
00226 virtual ~CAdaptiveTargetGraph();
00227
00228 void setCurrentTarget(CGraphTarget *target);
00229 virtual void addTarget(CGraphTarget *target);
00230
00231
00232 virtual void addState(CState *addState);
00233 };
00234
00235
00236
00237
00238 class CGraphDynamicProgramming;
00239
00240 class CGraphController : public CAgentController
00241 {
00242 protected:
00243 CGraphDynamicProgramming *graph;
00244 public:
00245 CGraphController(CActionSet *actionSet, CGraphDynamicProgramming *graph);
00246 virtual ~CGraphController();
00247
00248 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *dataSet);
00249 };
00250
00251 class CGraphAdaptiveTargetDynamicProgramming;
00252
00253 class CAdaptiveTargetGraphController : public CAgentController
00254 {
00255 protected:
00256 CGraphAdaptiveTargetDynamicProgramming *adaptiveGraph;
00257
00258 public:
00259 CAdaptiveTargetGraphController(CActionSet *actionSet, CGraphAdaptiveTargetDynamicProgramming *adaptiveGraph);
00260
00261 virtual ~CAdaptiveTargetGraphController();
00262
00263 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *dataSet);
00264 };
00265
00266
00267 class CPolicyEvaluator;
00268 class CSupervisedLearner;
00269
00270 class CGraphValueFromValueFunctionCalculator : public CEvaluator
00271 {
00272 protected:
00273 CSupervisedLearner *learner;
00274 CGraphDynamicProgramming *graph;
00275
00276 CPolicyEvaluator *evaluator;
00277 public:
00278 CGraphValueFromValueFunctionCalculator(CGraphDynamicProgramming *l_graph, CSupervisedLearner *learner, CPolicyEvaluator *evaluator);
00279 virtual ~CGraphValueFromValueFunctionCalculator();
00280
00281 virtual double evaluate();
00282 };
00283
00284
00285
00286 class CStateCollectionImpl;
00287
00288 class CSamplingBasedTransitionModelFromTransitionFunction : public CSamplingBasedTransitionModel
00289 {
00290 protected:
00291 CTransitionFunction *transitionFunction;
00292
00293 CActionSet *availableActions;
00294 CActionSet *allActions;
00295
00296 CRewardFunction *rewardPrediction;
00297
00298 CStateCollectionImpl *predictState;
00299 public:
00300 CSamplingBasedTransitionModelFromTransitionFunction(CStateProperties *properties, CStateProperties *targetProperties, CActionSet *allActions, CTransitionFunction *transitionFunction, CRewardFunction *rewardFunction, std::list<CStateModifier *> *stateModifier, CRewardFunction *predictReward);
00301
00302 virtual ~CSamplingBasedTransitionModelFromTransitionFunction();
00303
00304 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);
00305 };
00306
00307
00308 class CGraphDebugger : public CSemiMDPRewardListener
00309 {
00310 protected:
00311 CGraphDynamicProgramming *graph;
00312
00313 double realRewardSum;
00314 double graphRewardSum;
00315
00316 CStateModifier *hcState;
00317
00318 int step;
00319 public:
00320 CGraphDebugger(CGraphDynamicProgramming *graph, CRewardFunction *reward, CStateModifier *hcState);
00321
00322 virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);
00323
00324 virtual void newEpisode();
00325
00326 };
00327
00328 #endif
00329