00001 #ifndef C_EXTRA_TREES__H
00002 #define C_EXTRA_TREES__H
00003
00004 #include "ctrees.h"
00005
00006 #include <newmat/newmat.h>
00007
00008 class CDataSet;
00009 class CDataSet1D;
00010
00011 class CExtraTreesSplittingConditionFactory : public CSplittingConditionFactory
00012 {
00013 protected:
00014 unsigned int K;
00015 unsigned int n_min;
00016
00017 double outTreshold;
00018
00019 CDataSet *inputData;
00020 CDataSet1D *outputData;
00021 CDataSet1D *weightingData;
00022
00023 double getScore(CSplittingCondition *condition, DataSubset *dataSubset);
00024 public:
00025 CExtraTreesSplittingConditionFactory(CDataSet *inputData, CDataSet1D *outputData, unsigned int K, unsigned int n_min, double outTresh = 0.0, CDataSet1D *weightingData = NULL);
00026 virtual ~CExtraTreesSplittingConditionFactory();
00027
00028 virtual CSplittingCondition *createSplittingCondition(DataSubset *dataSubset);
00029
00030 virtual bool isLeaf(DataSubset *dataSubset);
00031 };
00032
00033 template <typename TreeData> class CExtraTree : public CTree<TreeData>
00034 {
00035 public:
00036 CExtraTree(CDataSet *inputData, CDataSet1D *outputData, CTreeDataFactory<TreeData> *dataFactory, unsigned int K,unsigned int n_min, double outTresh, CDataSet1D *weightingData = NULL) : CTree<TreeData>(inputData->getNumDimensions())
00037 {
00038 CSplittingConditionFactory *splittingFactory = new CExtraTreesSplittingConditionFactory(inputData, outputData, K, n_min, outTresh, weightingData);
00039 CTree<TreeData>::createTree(inputData, splittingFactory, dataFactory);
00040 delete splittingFactory;
00041 };
00042
00043 virtual ~CExtraTree()
00044 {
00045
00046 };
00047 };
00048
00049 class CExtraRegressionTree : public CExtraTree<double>
00050 {
00051 public:
00052 CExtraRegressionTree(CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned int n_min, double treshold, CDataSet1D *weightingData = NULL);
00053 virtual ~CExtraRegressionTree();
00054 };
00055
00056
00057
00058
00059
00060 #endif