Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

ctrees.h

Go to the documentation of this file.
00001 #ifndef CTREES__H
00002 #define CTREES__H
00003 
#include "cinputdata.h"

#include <newmat/newmat.h>
#include <algorithm>
#include <list>
#include <vector>
#include <stdio.h>
00010 
00011 
00012 
00013 
00014 class CSplittingCondition
00015 {
00016         public:
00017                 virtual ~CSplittingCondition() {};
00018                 virtual bool isLeftNode(ColumnVector *input) = 0;
00019 };
00020 
00021 class C1DSplittingCondition : public CSplittingCondition
00022 {
00023         protected:
00024                 int dimension;
00025                 double treshold;
00026         public:
00027                 C1DSplittingCondition(int dimension, double treshold);
00028                 virtual ~C1DSplittingCondition();
00029                 
00030                 virtual bool isLeftNode(ColumnVector *input);
00031 
00032                 int getDimension();
00033                 double getTreshold();
00034 };
00035 
00036 class CSplittingConditionFactory 
00037 {
00038         protected:
00039 
00040         public:
00041                 virtual ~CSplittingConditionFactory() {};
00042                 virtual CSplittingCondition *createSplittingCondition(DataSubset *dataSubset) = 0;
00043 
00044                 virtual bool isLeaf(DataSubset *dataSubset) = 0;
00045 };
00046 
00047 
00048 
00049 template<typename TreeData> class CTreeDataFactory
00050 {
00051         public:
00052                 virtual ~CTreeDataFactory() {};
00053                 
00054                 virtual TreeData createTreeData(DataSubset *dataSubset, int numLeaves) = 0;
00055                 virtual void deleteData(TreeData ){}; 
00056 };
00057                 
00058 
00059 template<typename TreeData> class CLeaf;
00060 
00061 template<typename TreeData> class CTreeElement
00062 {
00063         protected:
00064                 CTreeElement *parent;
00065 
00066         public:
00067                 CTreeElement(CTreeElement<TreeData> *l_parent)
00068                 {
00069                         parent = l_parent;
00070                 };
00071                 virtual ~CTreeElement() {};
00072 
00073                 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input) = 0;
00074                 
00075                 CTreeElement<TreeData> *getParent() {return parent;};
00076 
00077                 virtual int getDepth() {return 0;};
00078 
00079                 virtual bool isLeaf() {return false;};
00080 };
00081 
00082 template<typename TreeData> class CNode : public CTreeElement<TreeData>
00083 {
00084         protected:
00085                 CSplittingCondition *split;
00086                 CTreeElement<TreeData> *leftElement;
00087                 CTreeElement<TreeData> *rightElement;
00088         public:
00089                 CNode(CTreeElement<TreeData> *parent, CSplittingCondition *l_condition, CTreeElement<TreeData> *l_leftElement, CTreeElement<TreeData> *l_rightElement);
00090                 virtual ~CNode();
00091         
00092                 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input);
00093                 virtual int getDepth();
00094 
00095                 CTreeElement<TreeData> *getLeftElement() {return leftElement;};
00096                 CTreeElement<TreeData> *getRightElement() {return rightElement;};
00097 
00098                 void setLeftElement(CTreeElement<TreeData> *l_leftElement) {leftElement = l_leftElement;};
00099                 void setRightElement(CTreeElement<TreeData> *l_rightElement) {rightElement = l_rightElement;};
00100 
00101 
00102                 CSplittingCondition *getSplittingCondition();
00103 };
00104 
00105 template<typename TreeData> CNode<TreeData>::CNode(CTreeElement<TreeData> *parent, CSplittingCondition *l_condition, CTreeElement<TreeData> *l_leftElement, CTreeElement<TreeData> *l_rightElement) : CTreeElement<TreeData>(parent)
00106 {
00107         split = l_condition;
00108         leftElement = l_leftElement;
00109         rightElement = l_rightElement;
00110 };
00111         
00112 template<typename TreeData> CNode<TreeData>::~CNode()
00113 {
00114         delete split;
00115         delete rightElement;
00116         delete leftElement;
00117 }
00118 
00119 template<typename TreeData> CSplittingCondition *CNode<TreeData>::getSplittingCondition()
00120 {
00121         return split;
00122 }
00123         
00124 template<typename TreeData> CLeaf<TreeData> *CNode<TreeData>::getLeaf(ColumnVector *input)
00125 {
00126         bool isLeft = split->isLeftNode(input);
00127         if (isLeft)
00128         {
00129                 return leftElement->getLeaf(input);
00130         }
00131         else
00132         {
00133                 return rightElement->getLeaf(input);
00134         }
00135 }
00136 
00137 template<typename TreeData> int CNode<TreeData>::getDepth()
00138 {
00139         return std::max(leftElement->getDepth(), rightElement->getDepth()) + 1;
00140 }
00141 
00142 template <typename TreeData> class CLeaf : public CTreeElement<TreeData>
00143 {
00144         protected:
00145                 TreeData data;
00146                 CTreeDataFactory<TreeData> *dataFactory;
00147                 
00148                 int numLeaf;
00149                 DataSubset *subset;
00150         public:
00151                 CLeaf(CTreeElement<TreeData> *parent, TreeData l_data, DataSubset *subset, int numLeaf, CTreeDataFactory<TreeData> *l_dataFactory);
00152 ;
00153                 virtual ~CLeaf();
00154                                 
00155                 virtual CLeaf<TreeData> *getLeaf(ColumnVector *);
00156                 virtual TreeData getTreeData();
00157 
00158                 virtual bool isLeaf() {return true;};
00159 
00160                 int getLeafNumber() {return numLeaf;};
00161                 int getNumSamples() {return subset->size();};
00162 
00163                 DataSubset *getDataSet() {return subset;};
00164 };
00165 
00166 template <typename TreeData> CLeaf<TreeData>::CLeaf(CTreeElement<TreeData> *parent, TreeData l_data, DataSubset *l_subset, int l_numLeaf, CTreeDataFactory<TreeData> *l_dataFactory) : CTreeElement<TreeData>(parent)
00167 {
00168         data = l_data;
00169         dataFactory = l_dataFactory;
00170         numLeaf = l_numLeaf;
00171         subset = new DataSubset(*l_subset);
00172 };
00173                 
00174 template <typename TreeData> CLeaf<TreeData>::~CLeaf()
00175 {
00176         dataFactory->deleteData(data);
00177         delete subset;
00178 };
00179                                 
00180 template <typename TreeData> TreeData CLeaf<TreeData>::getTreeData()
00181 {
00182         //printf("Returning Data from Leave %d\n", numLeaf);
00183         return data;
00184 }
00185 
00186 template <typename TreeData> CLeaf<TreeData> *CLeaf<TreeData>::getLeaf(ColumnVector *)
00187 {
00188         //printf("Returning Data from Leave %d\n", numLeaf);
00189         return this;
00190 }
00191 
00192 
00193 template<typename TreeData> class CTree : public CMapping<TreeData>
00194 {
00195         protected:
00196                 CTreeElement<TreeData> *root;
00197                 
00198                 CTreeDataFactory<TreeData> *dataFactory;
00199                 
00200                 int numLeaves;
00201                 
00202                 virtual void createTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory, bool createLeaves = true);
00203                 
00204                 
00205                 virtual CTreeElement<TreeData> *createNode(CTreeElement<TreeData> *parent, CDataSet *inputData, DataSubset *inputDataSubset, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *dataFactory);
00206         
00207                 CTree(int numDim);
00208                 CLeaf<TreeData> **leaves;
00209 
00210                 virtual int setLeaves(CTreeElement<TreeData> *element, int numLeaf);
00211 
00212                 CDataSet *inputData;
00213 
00214                 virtual TreeData doGetOutputValue(ColumnVector *input);
00215         public: 
00216                 
00217                 
00218                 CTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory);
00219                 
00220                 
00221                 virtual ~CTree();
00222                 
00223                 CLeaf<TreeData> *getLeaf(int index);            
00224 
00225                 
00226                 int getNumLeaves();
00227                 int getDepth();
00228                 int getNumSamples();
00229 
00230                 CTreeElement<TreeData> *getRoot() {return root;};
00231 
00232                 CDataSet *getInputData() {return inputData;};
00233 
00234                 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input);  
00235                 CTreeDataFactory<TreeData> *getDataFactory() {return dataFactory;};     
00236         
00237                 virtual void addNewInput(int index, CSplittingConditionFactory *splitting);
00238 
00239                 void createLeavesArray();
00240 };
00241 
00242 
00243 template<typename TreeData> CTree<TreeData>::CTree(int numDim) : CMapping<TreeData>(numDim)
00244 {
00245         root = NULL;
00246         dataFactory = NULL;
00247 
00248         leaves = NULL;
00249 };
00250 
00251 
00252 template<typename TreeData> CTree<TreeData>::~CTree()
00253 {
00254         if (root != NULL)
00255         {
00256                 delete root;
00257         }
00258         if (leaves != NULL)
00259         {
00260                 delete [] leaves;
00261         }
00262 };
00263 
00264 template<typename TreeData> CLeaf<TreeData> * CTree<TreeData>::getLeaf(int index)
00265 {
00266         return leaves[index];
00267 }
00268 
00269 template<typename TreeData> void CTree<TreeData>::addNewInput(int index, CSplittingConditionFactory *splittingFactory)
00270 {
00271         ColumnVector *newInput = (*inputData)[index];
00272         CLeaf<TreeData> *leaf = getLeaf(newInput);
00273 
00274         if (leaf != NULL && leaf->getParent() != NULL)
00275         {
00276                 CNode<TreeData> *parent = dynamic_cast<CNode<TreeData> *>(leaf->getParent());
00277                 
00278                 bool isLeft = parent->getLeftElement() == leaf;
00279                 DataSubset *subset = leaf->getDataSet();
00280                 subset->insert(index);          
00281 
00282                 CTreeElement<TreeData> *newNode = createNode(parent, getInputData(), subset, splittingFactory, getDataFactory()); 
00283                 
00284                 if (isLeft)
00285                 {
00286                         parent->setLeftElement(newNode);
00287                 }       
00288                 else
00289                 {
00290                         parent->setRightElement(newNode);
00291                 }
00292                 delete leaf;
00293                 numLeaves --;  
00294 
00295                 
00296         }
00297         else
00298         {
00299                 if (root != NULL)
00300                 {
00301                         delete root;
00302                 }
00303 
00304                 createTree(getInputData(), splittingFactory, getDataFactory(), false);
00305         }
00306         //printf("DataSetSize %d, Leaves: %d\n", inputData->size(), numLeaves);
00307 }
00308 
00309 template<typename TreeData> int CTree<TreeData>::getNumLeaves()
00310 {
00311         return numLeaves;
00312 }
00313 
00314 template<typename TreeData> int CTree<TreeData>::getDepth()
00315 {
00316         return root->getDepth();
00317 }
00318                 
00319 template<typename TreeData> CTree<TreeData>::CTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory) : CMapping<TreeData>(inputData->getNumDimensions())
00320 {
00321         createTree(inputData, splittingFactory, l_dataFactory);
00322 };
00323                 
00324                 
00325 template<typename TreeData>  TreeData CTree<TreeData>::doGetOutputValue(ColumnVector *input)
00326 {
00327         return root->getLeaf(input)->getTreeData();
00328 };
00329 
00330 template<typename TreeData>  CLeaf<TreeData> * CTree<TreeData>::getLeaf(ColumnVector *input)
00331 {
00332         if (root)
00333         {
00334                 ColumnVector *l_input = CMapping<TreeData>::getPreprocessedInput(input);
00335         
00336                 return root->getLeaf(l_input);
00337         }
00338         else
00339         {
00340                 return NULL;
00341         }
00342 };
00343 
00344 
00345 template<typename TreeData> void CTree<TreeData>::createTree(CDataSet *l_inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory, bool setLeaves)
00346 {
00347         inputData = l_inputData;
00348         dataFactory = l_dataFactory;
00349 
00350         DataSubset subset;
00351         numLeaves = 0;
00352         
00353         if (l_inputData->size() > 0)
00354         { 
00355                 for (unsigned int i = 0; i < l_inputData->size(); i ++)
00356                 {
00357                         subset.insert(i);
00358                 }
00359                 dataFactory = l_dataFactory;
00360                 root = createNode(NULL, l_inputData, &subset, splittingFactory, l_dataFactory);
00361                 
00362                 if (setLeaves)
00363                 {
00364                         createLeavesArray();
00365                 }
00366         }
00367 }
00368 
00369 template<typename TreeData> void CTree<TreeData>::createLeavesArray()
00370 {
00371         if (leaves)
00372         {
00373                 delete leaves;
00374         }
00375 
00376         leaves = new CLeaf<TreeData>*[numLeaves];
00377         setLeaves(root, 0);     
00378 }
00379 
00380 template<typename TreeData> int CTree<TreeData>::getNumSamples()
00381 {
00382         int samples = 0;
00383         for (int i = 0; i < numLeaves; i ++)
00384         {
00385                 samples += leaves[i]->getDataSet()->size();
00386         }
00387         return samples;
00388 }
00389 
00390 
00391 
00392 template<typename TreeData> int CTree<TreeData>::setLeaves(CTreeElement<TreeData> *element, int numLeaf)
00393 {
00394         if (element->isLeaf())
00395         {
00396                 CLeaf<TreeData> *leaf = dynamic_cast<CLeaf<TreeData> *>(element);
00397                 leaves[numLeaf] = leaf;
00398                 return 1;
00399         }
00400         else
00401         {
00402                 int newLeaves = 0;
00403                 CNode<TreeData> *node = dynamic_cast<CNode<TreeData> *>(element);
00404                 
00405                 newLeaves += setLeaves(node->getLeftElement(), numLeaf);
00406                 
00407                 newLeaves += setLeaves(node->getRightElement(), numLeaf + newLeaves);
00408                 
00409                 return newLeaves;
00410         }
00411         
00412 }
00413 
00414 template<typename TreeData>  CTreeElement<TreeData> *CTree<TreeData>::createNode(CTreeElement<TreeData> *parent,CDataSet *inputData, DataSubset *inputDataSubset, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *dataFactory)
00415 {
00416         //printf("Creating Node with %d Inputs\n", inputDataSubset->size());
00417         
00418         bool isLeaf = splittingFactory->isLeaf(inputDataSubset);
00419         CTreeElement<TreeData> *newNode = NULL;
00420                         
00421         if (isLeaf)
00422         {
00423                 TreeData data = dataFactory->createTreeData(inputDataSubset, numLeaves);
00424                 newNode = new CLeaf<TreeData>(parent, data,inputDataSubset, numLeaves, dataFactory);
00425                 numLeaves ++;
00426         }
00427         else
00428         {
00429                 CSplittingCondition *split = splittingFactory->createSplittingCondition(inputDataSubset);
00430                                 
00431                 DataSubset leftSet;
00432                 DataSubset rightSet;
00433                                 
00434                 DataSubset::iterator it = inputDataSubset->begin();
00435                                 
00436                 for (; it != inputDataSubset->end(); it++)
00437                 {
00438                         ColumnVector *input = (*inputData)[*it];
00439                                         
00440                         if (split->isLeftNode(input))
00441                         {
00442                                 leftSet.insert(*it);
00443                         }
00444                         else
00445                         {
00446                                 rightSet.insert(*it);
00447                         }
00448                 }
00449                 //printf("LeftSet : %d, Right Set : %d\n", leftSet.size(), rightSet.size());
00450                 
00451                 CNode<TreeData> *newNode1 = new CNode<TreeData>(parent, split, NULL, NULL);
00452 
00453                 CTreeElement<TreeData> *leftElement = createNode(newNode1, inputData, &leftSet, splittingFactory, dataFactory);
00454                                 
00455                 CTreeElement<TreeData> *rightElement = createNode(newNode1, inputData, &rightSet, splittingFactory, dataFactory);
00456                                 
00457                 newNode1->setLeftElement(leftElement);
00458                 newNode1->setRightElement(rightElement);
00459 
00460                 newNode = newNode1;
00461                 
00462         }
00463         return newNode;
00464 }
00465 
00466 
00467 #endif
00468