Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

cforest.h

Go to the documentation of this file.
00001 //
00002 // C++ Interface: cforest
00003 //
00004 // Description: 
00005 //
00006 //
00007 // Author: Neumann Gerhard <gerhard@tu-graz.ac.at>, (C) 2006
00008 //
00009 // Copyright: See COPYING file that comes with this distribution
00010 //
00011 //
00012 
00013 #ifndef C_FOREST__H
00014 #define C_FOREST__H
00015 
00016 #include "ctrees.h"
00017 #include "cinputdata.h"
00018 
00020 
00021 template <typename TreeData> class CForest 
00022 {
00023         protected:
00024                 
00025                 CTree<TreeData> **forest;
00026                 int numTrees;   
00027         public:
00028                 CForest(int numTrees);
00029                 virtual ~CForest();
00030                 
00031                 virtual void getTreeDatas(ColumnVector *vector, TreeData *outputs);
00032                 
00033                 virtual void addTree(int index, CTree<TreeData> *tree);
00034                 virtual void removeTree(int index);
00035 
00036                  CTree<TreeData> *getTree(int index);           
00037 
00038                 int getNumTrees();
00039 
00040                 double getAverageDepth();
00041                 double getAverageNumLeaves();
00042 
00043                 virtual void getActiveLeafNumbers(ColumnVector *vector, int *leafNumbers);
00044                 virtual void getActiveLeaves(ColumnVector *vector, CLeaf<TreeData> **leafNumbers);
00045 
00046                 int getNumLeaves();
00047 };
00048 
00049 template<typename TreeData> CForest<TreeData>::CForest(int l_numTrees)
00050 {
00051         numTrees = l_numTrees;
00052         forest = new CTree<TreeData> *[numTrees];
00053         
00054         for (int i = 0; i < numTrees; i++)
00055         {
00056                 forest[i] = NULL;
00057         }
00058 }
00059 
00060 template<typename TreeData> CForest<TreeData>::~CForest()
00061 {
00062         delete [] forest;
00063 }
00064                 
00065 template<typename TreeData> void  CForest<TreeData>::getActiveLeafNumbers(ColumnVector *vector, int *leafNumbers)
00066 {
00067         int leaveSum = 0;
00068 
00069         for (int i = 0; i < numTrees; i ++)
00070         {
00071                 leafNumbers[i] = forest[i]->getLeaf(vector)->getLeafNumber() + leaveSum;
00072                 leaveSum += forest[i]->getNumLeaves();
00073         }
00074 }
00075 
00076 
00077 template<typename TreeData> void  CForest<TreeData>::getActiveLeaves(ColumnVector *vector, CLeaf<TreeData> **leaves)
00078 {
00079         for (int i = 0; i < numTrees; i ++)
00080         {
00081                 leaves[i] = forest[i]->getLeaf(vector);
00082         }
00083 }
00084 
00085 template<typename TreeData> int CForest<TreeData>::getNumLeaves()
00086 {
00087         int numLeaves = 0;      
00088         for (int i = 0; i < numTrees; i ++)
00089         {
00090                 numLeaves += forest[i]->getNumLeaves();
00091         }
00092         return numLeaves;
00093 }
00094 
00095 
00096 template<typename TreeData> double CForest<TreeData>::getAverageDepth()
00097 {
00098         double depth = 0;
00099         for (int i = 0; i < numTrees; i ++)
00100         {
00101                 depth += forest[i]->getDepth();
00102         }
00103         return depth / numTrees;
00104 }
00105 
00106 template<typename TreeData> double CForest<TreeData>::getAverageNumLeaves()
00107 {
00108 
00109         double leaves = 0;
00110         for (int i = 0; i < numTrees; i ++)
00111         {
00112                 leaves += forest[i]->getNumLeaves();
00113         }
00114         return leaves / numTrees;
00115 }
00116 
00117 
00118 
00119 template<typename TreeData> CTree<TreeData> *CForest<TreeData>::getTree(int index)
00120 {
00121         return forest[index];
00122 }
00123 
00124 template<typename TreeData>  void CForest<TreeData>::getTreeDatas(ColumnVector *vector, TreeData *outputs)
00125 {
00126         for (int i = 0; i < numTrees; i++)
00127         {
00128                 if (forest[i] != NULL)
00129                 {
00130                         TreeData element = forest[i]->getOutputValue(vector);
00131                         outputs[i] = element; 
00132                 }
00133         }
00134 }
00135                 
00136 template<typename TreeData> void CForest<TreeData>::addTree(int index, CTree<TreeData> *tree)
00137 {
00138         forest[index] = tree;
00139 }
00140 
00141 
00142 template<typename TreeData> void CForest<TreeData>::removeTree(int index)
00143 {
00144         forest[index] = NULL;;
00145 }
00146 
00147 template<typename TreeData> int CForest<TreeData>::getNumTrees()
00148 {
00149         return numTrees;
00150 }
00151 
00153 
00154 class CRegressionForest : public CForest<double>, public CMapping<double>
00155 {
00156         protected:
00157                 virtual double doGetOutputValue(ColumnVector *vector);
00158         public:
00159                 CRegressionForest(int numTrees, int numDim);
00160                 virtual ~CRegressionForest();
00161                 
00162                 
00163 
00164                 virtual void saveASCII(FILE *stream);
00165 };
00166 
00167 class CExtraTreeRegressionForest : public CRegressionForest
00168 {
00169         protected:
00170 
00171         public:
00172                 CExtraTreeRegressionForest(int numTrees, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned  int n_min, double treshold, CDataSet1D *weightData = NULL);
00173                 virtual ~CExtraTreeRegressionForest();
00174 
00175 };
00176 
00177 class CRegressionMultiMapping : public CMapping<double>
00178 {
00179 protected:
00180         CMapping<double> **mappings;
00181         int numMappings;
00182 
00183 
00184 
00185         double doGetOutputValue(ColumnVector *inputVector);
00186 public:
00187         bool deleteMappings;
00188 
00189         CRegressionMultiMapping(int numMappings, int numDimensions);
00190         virtual ~CRegressionMultiMapping();
00191 
00192         void addMapping(int index, CMapping<double> *mapping);
00193 };
00194 
00195 #endif