Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

ctreevfunction.h

Go to the documentation of this file.
00001 #ifndef C_TREEVFUNCTION__H
00002 #define C_TREEVFUNCTION__H
00003 
00004 #include "ctrees.h"
00005 #include "cforest.h"
00006 #include "cinputdata.h"
00007 #include "cvfunction.h"
00008 #include "ccontinuousactions.h"
00009 #include "cstatemodifier.h"
00010 #include "cstatecollection.h"
00011 #include "cstate.h"
00012 #include "cstateproperties.h"
00013 #include "caction.h"
00014 
00015 
00016 class CRegressionTreeFunction
00017 {
00018 protected:
00019         CMapping<double> *tree;
00020         
00021         int numDim;
00022 public:
00023         CRegressionTreeFunction(CMapping<double> *tree, int numDim);
00024 
00025         virtual ~CRegressionTreeFunction() {};
00026         void setTree(CMapping<double> *tree);
00027         CMapping<double> *getTree();
00028 
00029         int getNumDimensions();
00030         virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data) = 0;
00031 };
00032 
00033 class CRegressionTreeVFunction : public CAbstractVFunction, public CRegressionTreeFunction
00034 {
00035         protected :
00036         
00037         public :
00038                 CRegressionTreeVFunction(CStateProperties *properties, CMapping<double> *tree);
00039 
00040                 virtual ~CRegressionTreeVFunction() {};
00041 
00042                 virtual double getValue(CState *state);
00043                 virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data);
00044 
00045                 virtual void resetData();
00046                 virtual void saveData(FILE *stream);
00047 };
00048 
00049 class CRegressionTreeQFunction : public CContinuousActionQFunction, public CStateObject, public CRegressionTreeFunction
00050 {
00051         protected :
00052         
00053                 ColumnVector *buffVector;
00054         public :
00055                 CRegressionTreeQFunction(CContinuousAction *action, CStateProperties *properties, CMapping<double> *tree);
00056 
00057                 virtual ~CRegressionTreeQFunction();
00058 
00059                 virtual double getCAValue(CStateCollection *state, CContinuousActionData *data);
00060 
00061                 virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data);
00062 
00063                 virtual void resetData();
00064 };
00065 
00066 template<typename TreeData> class CForestFeatureCalculator : public CFeatureCalculator
00067 {
00068 protected:
00069         CForest<TreeData> *forest;
00070         CLeaf<TreeData> **activeLeaves;
00071 
00072         double  getLeafActivationFactor(CState *stateCol, CLeaf<TreeData> *targetState);
00073 public:
00074         CForestFeatureCalculator(CForest<TreeData> *forest, int offsetNumLeaves = 0);
00075         
00076         CForestFeatureCalculator(int numFeatures, int numActiveFeatures);
00077         
00078 
00079         virtual ~CForestFeatureCalculator();
00080 
00081         void getModifiedState(CStateCollection *stateCol, CState *targetState);
00082 
00083         void setForest(CForest<TreeData> *forest);
00084 };
00085 
00086 template<typename TreeData> CForestFeatureCalculator<TreeData>::CForestFeatureCalculator(CForest<TreeData> *l_forest, int offsetNumTrees) : CFeatureCalculator(l_forest->getNumLeaves() + offsetNumTrees, l_forest->getNumTrees())
00087 {
00088         forest = l_forest;
00089         activeLeaves = new CLeaf<TreeData>*[forest->getNumTrees()];
00090 }
00091 
00092 template<typename TreeData> CForestFeatureCalculator<TreeData>::CForestFeatureCalculator(int numFeatures, int numActiveFeatures) : CFeatureCalculator(numFeatures, numActiveFeatures)
00093 {
00094         forest = NULL;
00095         activeLeaves = new CLeaf<TreeData>*[getNumActiveFeatures()];
00096 }
00097 
00098 template<typename TreeData> CForestFeatureCalculator<TreeData>::~CForestFeatureCalculator()
00099 {
00100         delete [] activeLeaves;
00101 
00102         if (forest)
00103         {
00104                 delete forest;
00105         }
00106 }
00107 
00108 
00109 template<typename TreeData> void CForestFeatureCalculator<TreeData>::setForest(CForest<TreeData> *l_forest)
00110 {
00111         forest = l_forest;      
00112 }
00113 
00114 template<typename TreeData> void  CForestFeatureCalculator<TreeData>::getModifiedState(CStateCollection *stateCol, CState *targetState)
00115 {
00116         if (forest == NULL)
00117         {
00118                 targetState->resetState();
00119                 targetState->setNumActiveContinuousStates(0);
00120                 targetState->setNumActiveDiscreteStates(0);
00121         }
00122         else
00123         {       
00124                 CState *state = stateCol->getState(originalState);
00125         
00126                 forest->getActiveLeaves( state, activeLeaves);
00127         
00128                 int leafSum = 0;
00129         
00130                 targetState->setNumActiveContinuousStates(numActiveFeatures);
00131                 targetState->setNumActiveDiscreteStates(numActiveFeatures);
00132                 
00133                 for (unsigned int i = 0; i < numActiveFeatures; i ++)
00134                 {
00135                         targetState->setDiscreteState(i, activeLeaves[i]->getLeafNumber() + leafSum);
00136                         targetState->setContinuousState(i, getLeafActivationFactor(state, activeLeaves[i]));
00137         
00138                         leafSum += forest->getTree( i)->getNumLeaves();
00139                 }
00140         }
00141 }
00142 
00143 template<typename TreeData> double  CForestFeatureCalculator<TreeData>::getLeafActivationFactor(CState *, CLeaf<TreeData> *)
00144 {
00145         return 1.0 / numActiveFeatures;
00146 }
00147 
00148 class CFeatureVRegressionTreeFunction : public CFeatureVFunction
00149 {
00150 protected:
00151 
00152 public:
00153         CFeatureVRegressionTreeFunction(CRegressionForest *regTree, CFeatureCalculator *featCalc);
00154 
00155         CFeatureVRegressionTreeFunction(int numFeatures);
00156 
00157         void setForest(CRegressionForest *regTree, CFeatureCalculator *featCalc);
00158         
00159         virtual double getValue(CState *state);
00160 
00161         virtual void copy(CLearnDataObject *vFunction);
00162 };
00163 
00164 #endif