00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #ifndef C_RBFTREES__H
00014 #define C_RBFTREES__H
00015
00016
00017
00018 #include "cinputdata.h"
00019 #include "ctrees.h"
00020 #include "cforest.h"
00021 #include "cextratrees.h"
00022 #include "cnearestneighbor.h"
00023
00024 #include <newmat/newmat.h>
00025 #include <list>
00026 #include <vector>
00027 #include <stdio.h>
00028
00029 class CRBFBasisFunction
00030 {
00031 protected:
00032 ColumnVector *center;
00033 ColumnVector *sigma;
00034 public:
00035 CRBFBasisFunction(ColumnVector *center, ColumnVector *sigma);
00036
00037
00038 double getActivationFactor(ColumnVector *x);
00039
00040 ColumnVector *getCenter();
00041 ColumnVector *getSigma();
00042
00043 void setSigma(ColumnVector *sigma);
00044 void setCenter(ColumnVector *center);
00045 };
00046
00047 class CRBFBasisFunctionLinearWeight : public CRBFBasisFunction
00048 {
00049 protected:
00050 double weight;
00051 public:
00052 CRBFBasisFunctionLinearWeight(ColumnVector *center, ColumnVector *sigma, double weight);
00053
00054
00055 double getOutputWeight();
00056 void setWeight(double weight);
00057 };
00058
00059 class CRBFDataFactory : public CTreeDataFactory<CRBFBasisFunction *>
00060 {
00061 protected:
00062 ColumnVector *minVar;
00063 ColumnVector *varMultiplier;
00064
00065 CDataSet *inputData;
00066 public:
00067 CRBFDataFactory(CDataSet *inputData, ColumnVector *varMultiplier, ColumnVector *minVar);
00068 CRBFDataFactory(CDataSet *inputData);
00069 virtual ~CRBFDataFactory();
00070
00071
00072 virtual CRBFBasisFunction *createTreeData(DataSubset *dataSubset, int numLeaves);
00073 virtual void deleteData(CRBFBasisFunction *basisFunction);
00074 };
00075
00076 class CRBFLinearWeightDataFactory : public CTreeDataFactory<CRBFBasisFunctionLinearWeight *>
00077 {
00078 protected:
00079 ColumnVector *minVar;
00080 ColumnVector *varMultiplier;
00081
00082 CDataSet *inputData;
00083 CDataSet1D *outputData;
00084 public:
00085 CRBFLinearWeightDataFactory(CDataSet *inputData, CDataSet1D *outputData, ColumnVector *varMultiplier, ColumnVector *minVar);
00086 virtual ~CRBFLinearWeightDataFactory();
00087
00088
00089 virtual CRBFBasisFunctionLinearWeight *createTreeData(DataSubset *dataSubset, int numLeaves);
00090 virtual void deleteData(CRBFBasisFunctionLinearWeight *basisFunction);
00091 };
00092
00093 class CRBFExtraRegressionTree : public CExtraTree<CRBFBasisFunctionLinearWeight *>
00094 {
00095 public:
00096 CRBFExtraRegressionTree(CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00097 virtual ~CRBFExtraRegressionTree();
00098 };
00099
00100 class CKNearestRBFCenters : public CKNearestNeighborsTreeData<int, CRBFBasisFunctionLinearWeight *>
00101 {
00102 protected:
00103 ColumnVector *buffVector;
00104
00105 virtual void addDataElements(ColumnVector *point, CLeaf<CRBFBasisFunctionLinearWeight *> *leaf, CKDRectangle *rectangle);
00106
00107 public:
00108 CKNearestRBFCenters(CTree<CRBFBasisFunctionLinearWeight *> *tree, int K);
00109 virtual ~CKNearestRBFCenters();
00110 };
00111
00112 class CRBFRegressionTreeOutputMapping : public CMapping<double>
00113 {
00114 protected:
00115 CTree<CRBFBasisFunctionLinearWeight *> *tree;
00116 CKNearestRBFCenters *nearestLeaves;
00117
00118 double doGetOutputValue(ColumnVector *output);
00119 public:
00120 CRBFRegressionTreeOutputMapping(CTree<CRBFBasisFunctionLinearWeight *> *tree, int K);
00121 virtual ~CRBFRegressionTreeOutputMapping();
00122
00123
00124 };
00125
00126 class CRBFExtraRegressionForest : public CForest<CRBFBasisFunctionLinearWeight *>, public CMapping<double>
00127 {
00128 protected:
00129 CRBFRegressionTreeOutputMapping **mapping;
00130 CTreeDataFactory<CRBFBasisFunctionLinearWeight *> *dataFactory;
00131
00132 void initRBFMapping(int kNN);
00133
00134 virtual double doGetOutputValue(ColumnVector *input);
00135 public:
00136 CRBFExtraRegressionForest(int numTrees, int kNN, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00137 virtual ~CRBFExtraRegressionForest();
00138
00139
00140 };
00141
00142 class CRBFLinearWeightForest : public CForest<CRBFBasisFunctionLinearWeight *>, public CMapping<double>
00143 {
00144 public:
00145
00146 protected:
00147 CRBFLinearWeightForest(int numTrees, int numDim);
00148 virtual ~CRBFLinearWeightForest();
00149
00150 double getOutputValue(ColumnVector *inputData);
00151
00152 virtual void saveASCII(FILE *stream);
00153
00154 };
00155
00156
00157 class CExtraTreeRBFLinearWeightForest : public CRBFLinearWeightForest
00158 {
00159 protected:
00160 CRBFLinearWeightDataFactory *dataFactory;
00161
00162 public:
00163
00164 CExtraTreeRBFLinearWeightForest(int numTrees, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00165 virtual ~CExtraTreeRBFLinearWeightForest();
00166 };
00167
00168
00169 #endif