00001 #ifndef C_TREEVFUNCTION__H
00002 #define C_TREEVFUNCTION__H
00003
00004 #include "ctrees.h"
00005 #include "cforest.h"
00006 #include "cinputdata.h"
00007 #include "cvfunction.h"
00008 #include "ccontinuousactions.h"
00009 #include "cstatemodifier.h"
00010 #include "cstatecollection.h"
00011 #include "cstate.h"
00012 #include "cstateproperties.h"
00013 #include "caction.h"
00014
00015
00016 class CRegressionTreeFunction
00017 {
00018 protected:
00019 CMapping<double> *tree;
00020
00021 int numDim;
00022 public:
00023 CRegressionTreeFunction(CMapping<double> *tree, int numDim);
00024
00025 virtual ~CRegressionTreeFunction() {};
00026 void setTree(CMapping<double> *tree);
00027 CMapping<double> *getTree();
00028
00029 int getNumDimensions();
00030 virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data) = 0;
00031 };
00032
00033 class CRegressionTreeVFunction : public CAbstractVFunction, public CRegressionTreeFunction
00034 {
00035 protected :
00036
00037 public :
00038 CRegressionTreeVFunction(CStateProperties *properties, CMapping<double> *tree);
00039
00040 virtual ~CRegressionTreeVFunction() {};
00041
00042 virtual double getValue(CState *state);
00043 virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data);
00044
00045 virtual void resetData();
00046 virtual void saveData(FILE *stream);
00047 };
00048
00049 class CRegressionTreeQFunction : public CContinuousActionQFunction, public CStateObject, public CRegressionTreeFunction
00050 {
00051 protected :
00052
00053 ColumnVector *buffVector;
00054 public :
00055 CRegressionTreeQFunction(CContinuousAction *action, CStateProperties *properties, CMapping<double> *tree);
00056
00057 virtual ~CRegressionTreeQFunction();
00058
00059 virtual double getCAValue(CStateCollection *state, CContinuousActionData *data);
00060
00061 virtual void getInputData(CStateCollection *state, CAction *action, ColumnVector *data);
00062
00063 virtual void resetData();
00064 };
00065
00066 template<typename TreeData> class CForestFeatureCalculator : public CFeatureCalculator
00067 {
00068 protected:
00069 CForest<TreeData> *forest;
00070 CLeaf<TreeData> **activeLeaves;
00071
00072 double getLeafActivationFactor(CState *stateCol, CLeaf<TreeData> *targetState);
00073 public:
00074 CForestFeatureCalculator(CForest<TreeData> *forest, int offsetNumLeaves = 0);
00075
00076 CForestFeatureCalculator(int numFeatures, int numActiveFeatures);
00077
00078
00079 virtual ~CForestFeatureCalculator();
00080
00081 void getModifiedState(CStateCollection *stateCol, CState *targetState);
00082
00083 void setForest(CForest<TreeData> *forest);
00084 };
00085
00086 template<typename TreeData> CForestFeatureCalculator<TreeData>::CForestFeatureCalculator(CForest<TreeData> *l_forest, int offsetNumTrees) : CFeatureCalculator(l_forest->getNumLeaves() + offsetNumTrees, l_forest->getNumTrees())
00087 {
00088 forest = l_forest;
00089 activeLeaves = new CLeaf<TreeData>*[forest->getNumTrees()];
00090 }
00091
00092 template<typename TreeData> CForestFeatureCalculator<TreeData>::CForestFeatureCalculator(int numFeatures, int numActiveFeatures) : CFeatureCalculator(numFeatures, numActiveFeatures)
00093 {
00094 forest = NULL;
00095 activeLeaves = new CLeaf<TreeData>*[getNumActiveFeatures()];
00096 }
00097
00098 template<typename TreeData> CForestFeatureCalculator<TreeData>::~CForestFeatureCalculator()
00099 {
00100 delete [] activeLeaves;
00101
00102 if (forest)
00103 {
00104 delete forest;
00105 }
00106 }
00107
00108
00109 template<typename TreeData> void CForestFeatureCalculator<TreeData>::setForest(CForest<TreeData> *l_forest)
00110 {
00111 forest = l_forest;
00112 }
00113
00114 template<typename TreeData> void CForestFeatureCalculator<TreeData>::getModifiedState(CStateCollection *stateCol, CState *targetState)
00115 {
00116 if (forest == NULL)
00117 {
00118 targetState->resetState();
00119 targetState->setNumActiveContinuousStates(0);
00120 targetState->setNumActiveDiscreteStates(0);
00121 }
00122 else
00123 {
00124 CState *state = stateCol->getState(originalState);
00125
00126 forest->getActiveLeaves( state, activeLeaves);
00127
00128 int leafSum = 0;
00129
00130 targetState->setNumActiveContinuousStates(numActiveFeatures);
00131 targetState->setNumActiveDiscreteStates(numActiveFeatures);
00132
00133 for (unsigned int i = 0; i < numActiveFeatures; i ++)
00134 {
00135 targetState->setDiscreteState(i, activeLeaves[i]->getLeafNumber() + leafSum);
00136 targetState->setContinuousState(i, getLeafActivationFactor(state, activeLeaves[i]));
00137
00138 leafSum += forest->getTree( i)->getNumLeaves();
00139 }
00140 }
00141 }
00142
00143 template<typename TreeData> double CForestFeatureCalculator<TreeData>::getLeafActivationFactor(CState *, CLeaf<TreeData> *)
00144 {
00145 return 1.0 / numActiveFeatures;
00146 }
00147
00148 class CFeatureVRegressionTreeFunction : public CFeatureVFunction
00149 {
00150 protected:
00151
00152 public:
00153 CFeatureVRegressionTreeFunction(CRegressionForest *regTree, CFeatureCalculator *featCalc);
00154
00155 CFeatureVRegressionTreeFunction(int numFeatures);
00156
00157 void setForest(CRegressionForest *regTree, CFeatureCalculator *featCalc);
00158
00159 virtual double getValue(CState *state);
00160
00161 virtual void copy(CLearnDataObject *vFunction);
00162 };
00163
00164 #endif