Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

csamplingbasedmodel.h

Go to the documentation of this file.
00001 #ifndef C_SAMPLINGBASED_MODEL
00002 #define C_SAMPLINGBASED_MODEL
00003 
00004 
00005 #include "cagentlistener.h"
00006 #include "cagentcontroller.h"
00007 #include "clearndataobject.h"
00008 #include "cinputdata.h"
00009 #include "cbaseobjects.h"
00010 #include "cparameters.h"
00011 #include "cevaluator.h"
00012 
00013 class CEpisodeHistory;
00014 class CAction;
00015 class CTransitionFunction;
00016 class CRewardLogger;
00017 class CActionSet;
00018 class CStateModifier;
00019 class CKDTree;
00020 class CKNearestNeighbors;
00021 class CRangeSearch;
00022 class CKDRectangle;
00023 
00024 class CContinuousStateList : public CDataSet, public CSemiMDPListener, public CLearnDataObject, public CStateObject 
00025 {
00026 protected:
00027         CEpisodeHistory *initLogger;
00028         
00029         CKNearestNeighbors *nearestNeighbor;
00030         CRangeSearch *rangeSearch;
00031 
00032         CKDTree *kdTree;
00033         
00034 
00035 public:
00036         
00037 
00038         CContinuousStateList(CStateProperties *properties);
00039         virtual ~CContinuousStateList();
00040 
00041         virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00042         
00043         virtual int addState(CState *state, double minDist = -1.0);     
00044 
00045         virtual void createStateList(CEpisodeHistory *history, bool useInitLogger = false);
00046 
00047         virtual void resetData();
00048 
00049         virtual void loadData(FILE *);
00050         virtual void saveData(FILE *);
00051 
00052         void initNearestNeighborSearch();
00053         void disableNearestNeighborSearch();
00054 
00055         virtual bool isMember(ColumnVector *point);
00056         virtual void getNearestNeighbor(ColumnVector *point, int &index, double &distance);
00057 
00058         virtual void getSamplesInRange(CKDRectangle *rectangle, DataSubset *subset);
00059 
00060         virtual CKDTree *getKDTree() {return kdTree;};
00061 
00062 };
00063 
00064 
00065 
00066 class CSampleTransition
00067 {
00068         public:
00069                 CSampleTransition(CState *state, CActionSet *availableActions, double reward, CActionData *actionData = NULL);
00070 
00071                 virtual ~CSampleTransition();   
00072 
00073                 CState *state;
00074                 CActionSet *availableActions;
00075                 double reward;
00076                 CActionData *actionData;
00077 };
00078 
00079 class CSamplingBasedTransitionModel : public CActionObject, CSemiMDPRewardListener, public CLearnDataObject, public CStateObject  
00080 {
00081 protected:
00082         CEpisodeHistory *initLogger;
00083         CRewardLogger *initRewardLogger;
00084 
00085         typedef std::map<CAction *, CSampleTransition *> Transitions;
00086 
00087         CRewardFunction *rewardFunction;
00088         CStateProperties *targetProperties;             
00089 
00090         std::map<int, Transitions *> *transitions;
00091         
00092         CContinuousStateList *stateList;
00093                 
00094         void clearTransitions();
00095         void addTransition(int index, CAction *action, CStateCollection *state, double reward);
00096 public:
00097         CSamplingBasedTransitionModel(CStateProperties *properties, CStateProperties *targetProperties, CActionSet *actions, CRewardFunction *rewardFunction);
00098         virtual ~CSamplingBasedTransitionModel();
00099 
00100         virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00101         
00102         virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);  
00103 
00104         virtual void resetData();
00105         virtual void loadData(FILE *);
00106         virtual void saveData(FILE *);
00107 
00108         int getNumStates();
00109 
00110         std::map<CAction *, CSampleTransition *> *getTransitions(int index);
00111 
00112         virtual void createStateList(CEpisodeHistory *history, CRewardLogger *logger, bool useInitLogger = false);
00113 
00114         CContinuousStateList *getStateList();
00115 };
00116 
00117 class CGraphTransition
00118 {
00119 public:
00120         CGraphTransition(int newStateIndex, double reward, double discountFactor,  CAction *action, CActionData *actionData = NULL);
00121 
00122         virtual ~CGraphTransition();    
00123 
00124         int newStateIndex;
00125         double reward;
00126         CActionData *actionData;
00127         CAction *action;
00128 
00129         double discountFactor;
00130 
00131         virtual double getReward();
00132 };
00133 
00134 
00135 class CSamplingBasedGraph : public CActionObject, public CStateObject, public CLearnDataObject
00136 {
00137 protected:
00138         
00139         typedef std::list<CGraphTransition *> Transitions;
00140 
00141         std::map<int, Transitions *> *transitions;
00142         
00143         CContinuousStateList *stateList;
00144                 
00145         void clearTransitions();
00146         virtual void addTransition(int index, int newIndex, CAction *action,  CActionData *actionData, double reward, double discountFactor);
00147 
00148         int numTransitions;
00149 public:
00150         CSamplingBasedGraph(CContinuousStateList *stateList, CActionSet *actions);
00151         virtual ~CSamplingBasedGraph();
00152 
00153         virtual void resetData();
00154 
00155         int getNumStates();
00156 
00157         std::list<CGraphTransition *> *getTransitions(int index);
00158 
00159         CContinuousStateList *getStateList();
00160 
00161         virtual void loadData(FILE *stream);
00162         virtual void saveData(FILE *stream);
00163 
00164         virtual void getConnectedNodes(int node, DataSubset *subset);
00165 
00166         virtual void createTransitions(); 
00167 
00168         virtual bool calculateTransition(int startNode, int endNode) = 0;
00169         virtual void getNeighboredNodes(int node, DataSubset *elementList) = 0;
00170 
00171         virtual bool isFinalNode(int node) = 0;
00172         virtual void addFinalTransition(int node) = 0;
00173 
00174         virtual void addState(CState *addState);
00175         virtual void addTransitions(int node, bool newNode = false);
00176 };
00177 
00178 
00179 
00180 class CGraphTarget
00181 {
00182 protected:
00183         CGraphTarget *nextTarget;
00184 public:
00185         CGraphTarget(CGraphTarget *nextTarget);
00186         virtual ~CGraphTarget();
00187 
00188         virtual bool isFinishedCanditate(ColumnVector *node) = 0;
00189         virtual bool isFinished(ColumnVector *oldNode, ColumnVector *newNode, double &reward) = 0;
00190 
00191         virtual bool isTargetForState(CStateCollection *state) = 0;
00192 
00193         CGraphTarget *getNextTarget() {return nextTarget;};
00194 };
00195 
00196 class CGraphTransitionAdaptiveTarget : public CGraphTransition
00197 {
00198 protected:
00199         std::map<CGraphTarget *, double> *targetReward;
00200         std::map<CGraphTarget *, bool> *targetReached;
00201         
00202         CGraphTarget **currentTarget;
00203 public:
00204         CGraphTransitionAdaptiveTarget(int newStateIndex, double reward, double discountFactor,  CAction *action, CActionData *actionData, CGraphTarget **currentTarget);
00205 
00206         virtual ~CGraphTransitionAdaptiveTarget();
00207 
00208         virtual double getReward();
00209 
00210         virtual double getReward(CGraphTarget *target);
00211         virtual bool isFinished(CGraphTarget *target);
00212         virtual void addTarget(CGraphTarget *target, double reward, bool isFinished);
00213 };
00214 
00215 class CAdaptiveTargetGraph : public CSamplingBasedGraph
00216 {
00217 protected:
00218         std::list<CGraphTarget *> *targetList;  
00219         CGraphTarget *currentTarget;
00220 
00221         virtual void addTransition(int index, int newIndex, CAction *action,  CActionData *actionData, double reward, double discountFactor);
00222 
00223         virtual void addTargetForNode(CGraphTarget *target, int node);
00224 public:
00225         CAdaptiveTargetGraph(CContinuousStateList *stateList, CActionSet *actions);
00226         virtual ~CAdaptiveTargetGraph();
00227 
00228         void setCurrentTarget(CGraphTarget *target);
00229         virtual void addTarget(CGraphTarget *target);
00230         
00231 
00232         virtual void addState(CState *addState);
00233 };
00234 
00235 
00236 
00237 
00238 class CGraphDynamicProgramming;
00239 
00240 class CGraphController : public CAgentController
00241 {
00242 protected:
00243         CGraphDynamicProgramming *graph;
00244 public:
00245         CGraphController(CActionSet *actionSet, CGraphDynamicProgramming *graph);
00246         virtual ~CGraphController();
00247 
00248         virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *dataSet);
00249 };
00250 
00251 class CGraphAdaptiveTargetDynamicProgramming;
00252 
00253 class CAdaptiveTargetGraphController : public CAgentController
00254 {
00255 protected:
00256         CGraphAdaptiveTargetDynamicProgramming *adaptiveGraph;
00257 
00258 public:
00259         CAdaptiveTargetGraphController(CActionSet *actionSet, CGraphAdaptiveTargetDynamicProgramming *adaptiveGraph);
00260 
00261         virtual ~CAdaptiveTargetGraphController();
00262         
00263         virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *dataSet);
00264 };
00265 
00266 
00267 class CPolicyEvaluator;
00268 class CSupervisedLearner;
00269 
00270 class CGraphValueFromValueFunctionCalculator : public CEvaluator
00271 {
00272 protected:
00273         CSupervisedLearner *learner;
00274         CGraphDynamicProgramming *graph;
00275         
00276         CPolicyEvaluator *evaluator;
00277 public:
00278         CGraphValueFromValueFunctionCalculator(CGraphDynamicProgramming *l_graph, CSupervisedLearner *learner, CPolicyEvaluator *evaluator);
00279         virtual ~CGraphValueFromValueFunctionCalculator();
00280     
00281         virtual double evaluate();
00282 };
00283 
00284 
00285 
00286 class CStateCollectionImpl;
00287 
00288 class CSamplingBasedTransitionModelFromTransitionFunction : public CSamplingBasedTransitionModel
00289 {
00290 protected:
00291         CTransitionFunction *transitionFunction;
00292         
00293         CActionSet *availableActions;
00294         CActionSet *allActions;
00295 
00296         CRewardFunction *rewardPrediction;
00297 
00298         CStateCollectionImpl *predictState;
00299 public:
00300         CSamplingBasedTransitionModelFromTransitionFunction(CStateProperties *properties, CStateProperties *targetProperties, CActionSet *allActions, CTransitionFunction *transitionFunction, CRewardFunction *rewardFunction, std::list<CStateModifier *> *stateModifier, CRewardFunction *predictReward);
00301 
00302         virtual ~CSamplingBasedTransitionModelFromTransitionFunction();
00303 
00304         virtual void nextStep(CStateCollection *oldState, CAction *action, double reward,  CStateCollection *newState); 
00305 };
00306 
00307 
00308 class CGraphDebugger : public CSemiMDPRewardListener
00309 {
00310 protected:
00311         CGraphDynamicProgramming *graph;
00312 
00313         double realRewardSum;
00314         double graphRewardSum;
00315 
00316         CStateModifier *hcState;
00317 
00318         int step;
00319 public:
00320         CGraphDebugger(CGraphDynamicProgramming *graph, CRewardFunction *reward, CStateModifier *hcState);
00321 
00322         virtual void nextStep(CStateCollection *oldState, CAction *action, double reward, CStateCollection *newState);
00323 
00324         virtual void newEpisode();
00325 
00326 };
00327 
00328 #endif
00329