00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef C_FOREST__H
00014 #define C_FOREST__H
00015
#include <cassert>

#include "ctrees.h"
#include "cinputdata.h"
00018
00020
00021 template <typename TreeData> class CForest
00022 {
00023 protected:
00024
00025 CTree<TreeData> **forest;
00026 int numTrees;
00027 public:
00028 CForest(int numTrees);
00029 virtual ~CForest();
00030
00031 virtual void getTreeDatas(ColumnVector *vector, TreeData *outputs);
00032
00033 virtual void addTree(int index, CTree<TreeData> *tree);
00034 virtual void removeTree(int index);
00035
00036 CTree<TreeData> *getTree(int index);
00037
00038 int getNumTrees();
00039
00040 double getAverageDepth();
00041 double getAverageNumLeaves();
00042
00043 virtual void getActiveLeafNumbers(ColumnVector *vector, int *leafNumbers);
00044 virtual void getActiveLeaves(ColumnVector *vector, CLeaf<TreeData> **leafNumbers);
00045
00046 int getNumLeaves();
00047 };
00048
00049 template<typename TreeData> CForest<TreeData>::CForest(int l_numTrees)
00050 {
00051 numTrees = l_numTrees;
00052 forest = new CTree<TreeData> *[numTrees];
00053
00054 for (int i = 0; i < numTrees; i++)
00055 {
00056 forest[i] = NULL;
00057 }
00058 }
00059
00060 template<typename TreeData> CForest<TreeData>::~CForest()
00061 {
00062 delete [] forest;
00063 }
00064
00065 template<typename TreeData> void CForest<TreeData>::getActiveLeafNumbers(ColumnVector *vector, int *leafNumbers)
00066 {
00067 int leaveSum = 0;
00068
00069 for (int i = 0; i < numTrees; i ++)
00070 {
00071 leafNumbers[i] = forest[i]->getLeaf(vector)->getLeafNumber() + leaveSum;
00072 leaveSum += forest[i]->getNumLeaves();
00073 }
00074 }
00075
00076
00077 template<typename TreeData> void CForest<TreeData>::getActiveLeaves(ColumnVector *vector, CLeaf<TreeData> **leaves)
00078 {
00079 for (int i = 0; i < numTrees; i ++)
00080 {
00081 leaves[i] = forest[i]->getLeaf(vector);
00082 }
00083 }
00084
00085 template<typename TreeData> int CForest<TreeData>::getNumLeaves()
00086 {
00087 int numLeaves = 0;
00088 for (int i = 0; i < numTrees; i ++)
00089 {
00090 numLeaves += forest[i]->getNumLeaves();
00091 }
00092 return numLeaves;
00093 }
00094
00095
00096 template<typename TreeData> double CForest<TreeData>::getAverageDepth()
00097 {
00098 double depth = 0;
00099 for (int i = 0; i < numTrees; i ++)
00100 {
00101 depth += forest[i]->getDepth();
00102 }
00103 return depth / numTrees;
00104 }
00105
00106 template<typename TreeData> double CForest<TreeData>::getAverageNumLeaves()
00107 {
00108
00109 double leaves = 0;
00110 for (int i = 0; i < numTrees; i ++)
00111 {
00112 leaves += forest[i]->getNumLeaves();
00113 }
00114 return leaves / numTrees;
00115 }
00116
00117
00118
00119 template<typename TreeData> CTree<TreeData> *CForest<TreeData>::getTree(int index)
00120 {
00121 return forest[index];
00122 }
00123
00124 template<typename TreeData> void CForest<TreeData>::getTreeDatas(ColumnVector *vector, TreeData *outputs)
00125 {
00126 for (int i = 0; i < numTrees; i++)
00127 {
00128 if (forest[i] != NULL)
00129 {
00130 TreeData element = forest[i]->getOutputValue(vector);
00131 outputs[i] = element;
00132 }
00133 }
00134 }
00135
00136 template<typename TreeData> void CForest<TreeData>::addTree(int index, CTree<TreeData> *tree)
00137 {
00138 forest[index] = tree;
00139 }
00140
00141
00142 template<typename TreeData> void CForest<TreeData>::removeTree(int index)
00143 {
00144 forest[index] = NULL;;
00145 }
00146
00147 template<typename TreeData> int CForest<TreeData>::getNumTrees()
00148 {
00149 return numTrees;
00150 }
00151
00153
00154 class CRegressionForest : public CForest<double>, public CMapping<double>
00155 {
00156 protected:
00157 virtual double doGetOutputValue(ColumnVector *vector);
00158 public:
00159 CRegressionForest(int numTrees, int numDim);
00160 virtual ~CRegressionForest();
00161
00162
00163
00164 virtual void saveASCII(FILE *stream);
00165 };
00166
00167 class CExtraTreeRegressionForest : public CRegressionForest
00168 {
00169 protected:
00170
00171 public:
00172 CExtraTreeRegressionForest(int numTrees, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned int n_min, double treshold, CDataSet1D *weightData = NULL);
00173 virtual ~CExtraTreeRegressionForest();
00174
00175 };
00176
00177 class CRegressionMultiMapping : public CMapping<double>
00178 {
00179 protected:
00180 CMapping<double> **mappings;
00181 int numMappings;
00182
00183
00184
00185 double doGetOutputValue(ColumnVector *inputVector);
00186 public:
00187 bool deleteMappings;
00188
00189 CRegressionMultiMapping(int numMappings, int numDimensions);
00190 virtual ~CRegressionMultiMapping();
00191
00192 void addMapping(int index, CMapping<double> *mapping);
00193 };
00194
00195 #endif