Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

crbftrees.h

Go to the documentation of this file.
00001 //
00002 // C++ Interface: crbftrees
00003 //
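00004 // Description: RBF basis functions used as tree leaf data, the factories that create them, and the extra-tree regression trees and forests built on them.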
00005 //
00006 //
00007 // Author: Neumann Gerhard <gerhard@tu-graz.ac.at>, (C) 2006
00008 //
00009 // Copyright: See COPYING file that comes with this distribution
00010 //
00011 //
00012 
00013 #ifndef C_RBFTREES__H
00014 #define C_RBFTREES__H
00015 
00016 
00017 
00018 #include "cinputdata.h"
00019 #include "ctrees.h"
00020 #include "cforest.h"
00021 #include "cextratrees.h"
00022 #include "cnearestneighbor.h"
00023 
00024 #include <newmat/newmat.h>
00025 #include <list>
00026 #include <vector>
00027 #include <stdio.h>
00028 
00029 class CRBFBasisFunction 
00030 {
00031         protected:
00032                 ColumnVector *center;
00033                 ColumnVector *sigma;
00034         public:
00035                 CRBFBasisFunction(ColumnVector *center, ColumnVector *sigma);
00036 
00037 
00038                 double getActivationFactor(ColumnVector *x);
00039                 
00040                 ColumnVector *getCenter();
00041                 ColumnVector *getSigma();
00042 
00043                 void setSigma(ColumnVector *sigma);
00044                 void setCenter(ColumnVector *center);
00045 };
00046 
00047 class CRBFBasisFunctionLinearWeight : public CRBFBasisFunction 
00048 {
00049         protected:
00050                 double weight;
00051         public:
00052                 CRBFBasisFunctionLinearWeight(ColumnVector *center, ColumnVector *sigma, double weight);
00053 
00054 
00055                 double getOutputWeight();
00056                 void setWeight(double weight);
00057 };
00058 
00059 class CRBFDataFactory : public CTreeDataFactory<CRBFBasisFunction *>
00060 {
00061 protected:
00062         ColumnVector *minVar;
00063         ColumnVector *varMultiplier;
00064 
00065         CDataSet *inputData;
00066 public:
00067         CRBFDataFactory(CDataSet *inputData, ColumnVector *varMultiplier, ColumnVector *minVar);
00068         CRBFDataFactory(CDataSet *inputData);
00069         virtual ~CRBFDataFactory();
00070 
00071 
00072         virtual CRBFBasisFunction *createTreeData(DataSubset *dataSubset, int numLeaves);
00073         virtual void deleteData(CRBFBasisFunction *basisFunction); 
00074 };
00075 
00076 class CRBFLinearWeightDataFactory : public CTreeDataFactory<CRBFBasisFunctionLinearWeight *>
00077 {
00078 protected:
00079         ColumnVector *minVar;
00080         ColumnVector *varMultiplier;
00081 
00082         CDataSet *inputData;
00083         CDataSet1D *outputData;
00084 public:
00085         CRBFLinearWeightDataFactory(CDataSet *inputData, CDataSet1D *outputData, ColumnVector *varMultiplier, ColumnVector *minVar);
00086         virtual ~CRBFLinearWeightDataFactory();
00087 
00088 
00089         virtual CRBFBasisFunctionLinearWeight *createTreeData(DataSubset *dataSubset, int numLeaves);
00090         virtual void deleteData(CRBFBasisFunctionLinearWeight *basisFunction); 
00091 };
00092 
00093 class CRBFExtraRegressionTree : public CExtraTree<CRBFBasisFunctionLinearWeight *>
00094 {
00095         public:
00096                 CRBFExtraRegressionTree(CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned  int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00097                 virtual ~CRBFExtraRegressionTree();
00098 };
00099 
00100 class CKNearestRBFCenters : public  CKNearestNeighborsTreeData<int, CRBFBasisFunctionLinearWeight *>
00101 {
00102         protected:
00103                 ColumnVector *buffVector;
00104 
00105                 virtual void addDataElements(ColumnVector *point, CLeaf<CRBFBasisFunctionLinearWeight *> *leaf, CKDRectangle *rectangle);
00106                 
00107         public:
00108                 CKNearestRBFCenters(CTree<CRBFBasisFunctionLinearWeight *> *tree, int K);
00109                 virtual ~CKNearestRBFCenters();
00110 };
00111 
00112 class CRBFRegressionTreeOutputMapping : public CMapping<double>
00113 {
00114         protected:
00115                 CTree<CRBFBasisFunctionLinearWeight *> *tree;
00116                 CKNearestRBFCenters *nearestLeaves;
00117 
00118                 double doGetOutputValue(ColumnVector *output);
00119         public:
00120                 CRBFRegressionTreeOutputMapping(CTree<CRBFBasisFunctionLinearWeight *> *tree, int K);
00121                 virtual ~CRBFRegressionTreeOutputMapping();
00122 
00123                 
00124 };
00125 
00126 class CRBFExtraRegressionForest : public CForest<CRBFBasisFunctionLinearWeight *>, public CMapping<double>
00127 {
00128 protected:
00129         CRBFRegressionTreeOutputMapping **mapping;
00130         CTreeDataFactory<CRBFBasisFunctionLinearWeight *> *dataFactory;
00131 
00132         void initRBFMapping(int kNN);
00133 
00134         virtual double doGetOutputValue(ColumnVector *input);
00135 public:
00136         CRBFExtraRegressionForest(int numTrees, int kNN, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned  int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00137         virtual ~CRBFExtraRegressionForest();
00138 
00139         
00140 };
00141 
00142 class CRBFLinearWeightForest : public CForest<CRBFBasisFunctionLinearWeight *>, public CMapping<double>
00143 {
00144         public:
00145 
00146         protected:
00147                 CRBFLinearWeightForest(int numTrees, int numDim);
00148                 virtual ~CRBFLinearWeightForest();
00149 
00150                 double getOutputValue(ColumnVector *inputData);
00151 
00152                 virtual void saveASCII(FILE *stream);
00153 
00154 };
00155 
00156 
00157 class CExtraTreeRBFLinearWeightForest : public CRBFLinearWeightForest
00158 {
00159         protected:
00160                 CRBFLinearWeightDataFactory *dataFactory;
00161 
00162         public:
00163 
00164                 CExtraTreeRBFLinearWeightForest(int numTrees, CDataSet *inputData, CDataSet1D *outputData, unsigned int K,unsigned  int n_min, double treshold, ColumnVector *varMultiplier, ColumnVector *minVar);
00165                 virtual ~CExtraTreeRBFLinearWeightForest();
00166 };
00167 
00168 
00169 #endif