00001 #ifndef CTREES__H
00002 #define CTREES__H
00003
#include "cinputdata.h"

#include <newmat/newmat.h>
#include <algorithm>
#include <list>
#include <vector>
#include <stdio.h>
00010
00011
00012
00013
00014 class CSplittingCondition
00015 {
00016 public:
00017 virtual ~CSplittingCondition() {};
00018 virtual bool isLeftNode(ColumnVector *input) = 0;
00019 };
00020
00021 class C1DSplittingCondition : public CSplittingCondition
00022 {
00023 protected:
00024 int dimension;
00025 double treshold;
00026 public:
00027 C1DSplittingCondition(int dimension, double treshold);
00028 virtual ~C1DSplittingCondition();
00029
00030 virtual bool isLeftNode(ColumnVector *input);
00031
00032 int getDimension();
00033 double getTreshold();
00034 };
00035
00036 class CSplittingConditionFactory
00037 {
00038 protected:
00039
00040 public:
00041 virtual ~CSplittingConditionFactory() {};
00042 virtual CSplittingCondition *createSplittingCondition(DataSubset *dataSubset) = 0;
00043
00044 virtual bool isLeaf(DataSubset *dataSubset) = 0;
00045 };
00046
00047
00048
00049 template<typename TreeData> class CTreeDataFactory
00050 {
00051 public:
00052 virtual ~CTreeDataFactory() {};
00053
00054 virtual TreeData createTreeData(DataSubset *dataSubset, int numLeaves) = 0;
00055 virtual void deleteData(TreeData ){};
00056 };
00057
00058
00059 template<typename TreeData> class CLeaf;
00060
00061 template<typename TreeData> class CTreeElement
00062 {
00063 protected:
00064 CTreeElement *parent;
00065
00066 public:
00067 CTreeElement(CTreeElement<TreeData> *l_parent)
00068 {
00069 parent = l_parent;
00070 };
00071 virtual ~CTreeElement() {};
00072
00073 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input) = 0;
00074
00075 CTreeElement<TreeData> *getParent() {return parent;};
00076
00077 virtual int getDepth() {return 0;};
00078
00079 virtual bool isLeaf() {return false;};
00080 };
00081
00082 template<typename TreeData> class CNode : public CTreeElement<TreeData>
00083 {
00084 protected:
00085 CSplittingCondition *split;
00086 CTreeElement<TreeData> *leftElement;
00087 CTreeElement<TreeData> *rightElement;
00088 public:
00089 CNode(CTreeElement<TreeData> *parent, CSplittingCondition *l_condition, CTreeElement<TreeData> *l_leftElement, CTreeElement<TreeData> *l_rightElement);
00090 virtual ~CNode();
00091
00092 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input);
00093 virtual int getDepth();
00094
00095 CTreeElement<TreeData> *getLeftElement() {return leftElement;};
00096 CTreeElement<TreeData> *getRightElement() {return rightElement;};
00097
00098 void setLeftElement(CTreeElement<TreeData> *l_leftElement) {leftElement = l_leftElement;};
00099 void setRightElement(CTreeElement<TreeData> *l_rightElement) {rightElement = l_rightElement;};
00100
00101
00102 CSplittingCondition *getSplittingCondition();
00103 };
00104
00105 template<typename TreeData> CNode<TreeData>::CNode(CTreeElement<TreeData> *parent, CSplittingCondition *l_condition, CTreeElement<TreeData> *l_leftElement, CTreeElement<TreeData> *l_rightElement) : CTreeElement<TreeData>(parent)
00106 {
00107 split = l_condition;
00108 leftElement = l_leftElement;
00109 rightElement = l_rightElement;
00110 };
00111
00112 template<typename TreeData> CNode<TreeData>::~CNode()
00113 {
00114 delete split;
00115 delete rightElement;
00116 delete leftElement;
00117 }
00118
00119 template<typename TreeData> CSplittingCondition *CNode<TreeData>::getSplittingCondition()
00120 {
00121 return split;
00122 }
00123
00124 template<typename TreeData> CLeaf<TreeData> *CNode<TreeData>::getLeaf(ColumnVector *input)
00125 {
00126 bool isLeft = split->isLeftNode(input);
00127 if (isLeft)
00128 {
00129 return leftElement->getLeaf(input);
00130 }
00131 else
00132 {
00133 return rightElement->getLeaf(input);
00134 }
00135 }
00136
00137 template<typename TreeData> int CNode<TreeData>::getDepth()
00138 {
00139 return std::max(leftElement->getDepth(), rightElement->getDepth()) + 1;
00140 }
00141
00142 template <typename TreeData> class CLeaf : public CTreeElement<TreeData>
00143 {
00144 protected:
00145 TreeData data;
00146 CTreeDataFactory<TreeData> *dataFactory;
00147
00148 int numLeaf;
00149 DataSubset *subset;
00150 public:
00151 CLeaf(CTreeElement<TreeData> *parent, TreeData l_data, DataSubset *subset, int numLeaf, CTreeDataFactory<TreeData> *l_dataFactory);
00152 ;
00153 virtual ~CLeaf();
00154
00155 virtual CLeaf<TreeData> *getLeaf(ColumnVector *);
00156 virtual TreeData getTreeData();
00157
00158 virtual bool isLeaf() {return true;};
00159
00160 int getLeafNumber() {return numLeaf;};
00161 int getNumSamples() {return subset->size();};
00162
00163 DataSubset *getDataSet() {return subset;};
00164 };
00165
00166 template <typename TreeData> CLeaf<TreeData>::CLeaf(CTreeElement<TreeData> *parent, TreeData l_data, DataSubset *l_subset, int l_numLeaf, CTreeDataFactory<TreeData> *l_dataFactory) : CTreeElement<TreeData>(parent)
00167 {
00168 data = l_data;
00169 dataFactory = l_dataFactory;
00170 numLeaf = l_numLeaf;
00171 subset = new DataSubset(*l_subset);
00172 };
00173
00174 template <typename TreeData> CLeaf<TreeData>::~CLeaf()
00175 {
00176 dataFactory->deleteData(data);
00177 delete subset;
00178 };
00179
00180 template <typename TreeData> TreeData CLeaf<TreeData>::getTreeData()
00181 {
00182
00183 return data;
00184 }
00185
00186 template <typename TreeData> CLeaf<TreeData> *CLeaf<TreeData>::getLeaf(ColumnVector *)
00187 {
00188
00189 return this;
00190 }
00191
00192
00193 template<typename TreeData> class CTree : public CMapping<TreeData>
00194 {
00195 protected:
00196 CTreeElement<TreeData> *root;
00197
00198 CTreeDataFactory<TreeData> *dataFactory;
00199
00200 int numLeaves;
00201
00202 virtual void createTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory, bool createLeaves = true);
00203
00204
00205 virtual CTreeElement<TreeData> *createNode(CTreeElement<TreeData> *parent, CDataSet *inputData, DataSubset *inputDataSubset, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *dataFactory);
00206
00207 CTree(int numDim);
00208 CLeaf<TreeData> **leaves;
00209
00210 virtual int setLeaves(CTreeElement<TreeData> *element, int numLeaf);
00211
00212 CDataSet *inputData;
00213
00214 virtual TreeData doGetOutputValue(ColumnVector *input);
00215 public:
00216
00217
00218 CTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory);
00219
00220
00221 virtual ~CTree();
00222
00223 CLeaf<TreeData> *getLeaf(int index);
00224
00225
00226 int getNumLeaves();
00227 int getDepth();
00228 int getNumSamples();
00229
00230 CTreeElement<TreeData> *getRoot() {return root;};
00231
00232 CDataSet *getInputData() {return inputData;};
00233
00234 virtual CLeaf<TreeData> *getLeaf(ColumnVector *input);
00235 CTreeDataFactory<TreeData> *getDataFactory() {return dataFactory;};
00236
00237 virtual void addNewInput(int index, CSplittingConditionFactory *splitting);
00238
00239 void createLeavesArray();
00240 };
00241
00242
00243 template<typename TreeData> CTree<TreeData>::CTree(int numDim) : CMapping<TreeData>(numDim)
00244 {
00245 root = NULL;
00246 dataFactory = NULL;
00247
00248 leaves = NULL;
00249 };
00250
00251
00252 template<typename TreeData> CTree<TreeData>::~CTree()
00253 {
00254 if (root != NULL)
00255 {
00256 delete root;
00257 }
00258 if (leaves != NULL)
00259 {
00260 delete [] leaves;
00261 }
00262 };
00263
00264 template<typename TreeData> CLeaf<TreeData> * CTree<TreeData>::getLeaf(int index)
00265 {
00266 return leaves[index];
00267 }
00268
00269 template<typename TreeData> void CTree<TreeData>::addNewInput(int index, CSplittingConditionFactory *splittingFactory)
00270 {
00271 ColumnVector *newInput = (*inputData)[index];
00272 CLeaf<TreeData> *leaf = getLeaf(newInput);
00273
00274 if (leaf != NULL && leaf->getParent() != NULL)
00275 {
00276 CNode<TreeData> *parent = dynamic_cast<CNode<TreeData> *>(leaf->getParent());
00277
00278 bool isLeft = parent->getLeftElement() == leaf;
00279 DataSubset *subset = leaf->getDataSet();
00280 subset->insert(index);
00281
00282 CTreeElement<TreeData> *newNode = createNode(parent, getInputData(), subset, splittingFactory, getDataFactory());
00283
00284 if (isLeft)
00285 {
00286 parent->setLeftElement(newNode);
00287 }
00288 else
00289 {
00290 parent->setRightElement(newNode);
00291 }
00292 delete leaf;
00293 numLeaves --;
00294
00295
00296 }
00297 else
00298 {
00299 if (root != NULL)
00300 {
00301 delete root;
00302 }
00303
00304 createTree(getInputData(), splittingFactory, getDataFactory(), false);
00305 }
00306
00307 }
00308
00309 template<typename TreeData> int CTree<TreeData>::getNumLeaves()
00310 {
00311 return numLeaves;
00312 }
00313
00314 template<typename TreeData> int CTree<TreeData>::getDepth()
00315 {
00316 return root->getDepth();
00317 }
00318
00319 template<typename TreeData> CTree<TreeData>::CTree(CDataSet *inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory) : CMapping<TreeData>(inputData->getNumDimensions())
00320 {
00321 createTree(inputData, splittingFactory, l_dataFactory);
00322 };
00323
00324
00325 template<typename TreeData> TreeData CTree<TreeData>::doGetOutputValue(ColumnVector *input)
00326 {
00327 return root->getLeaf(input)->getTreeData();
00328 };
00329
00330 template<typename TreeData> CLeaf<TreeData> * CTree<TreeData>::getLeaf(ColumnVector *input)
00331 {
00332 if (root)
00333 {
00334 ColumnVector *l_input = CMapping<TreeData>::getPreprocessedInput(input);
00335
00336 return root->getLeaf(l_input);
00337 }
00338 else
00339 {
00340 return NULL;
00341 }
00342 };
00343
00344
00345 template<typename TreeData> void CTree<TreeData>::createTree(CDataSet *l_inputData, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *l_dataFactory, bool setLeaves)
00346 {
00347 inputData = l_inputData;
00348 dataFactory = l_dataFactory;
00349
00350 DataSubset subset;
00351 numLeaves = 0;
00352
00353 if (l_inputData->size() > 0)
00354 {
00355 for (unsigned int i = 0; i < l_inputData->size(); i ++)
00356 {
00357 subset.insert(i);
00358 }
00359 dataFactory = l_dataFactory;
00360 root = createNode(NULL, l_inputData, &subset, splittingFactory, l_dataFactory);
00361
00362 if (setLeaves)
00363 {
00364 createLeavesArray();
00365 }
00366 }
00367 }
00368
00369 template<typename TreeData> void CTree<TreeData>::createLeavesArray()
00370 {
00371 if (leaves)
00372 {
00373 delete leaves;
00374 }
00375
00376 leaves = new CLeaf<TreeData>*[numLeaves];
00377 setLeaves(root, 0);
00378 }
00379
00380 template<typename TreeData> int CTree<TreeData>::getNumSamples()
00381 {
00382 int samples = 0;
00383 for (int i = 0; i < numLeaves; i ++)
00384 {
00385 samples += leaves[i]->getDataSet()->size();
00386 }
00387 return samples;
00388 }
00389
00390
00391
00392 template<typename TreeData> int CTree<TreeData>::setLeaves(CTreeElement<TreeData> *element, int numLeaf)
00393 {
00394 if (element->isLeaf())
00395 {
00396 CLeaf<TreeData> *leaf = dynamic_cast<CLeaf<TreeData> *>(element);
00397 leaves[numLeaf] = leaf;
00398 return 1;
00399 }
00400 else
00401 {
00402 int newLeaves = 0;
00403 CNode<TreeData> *node = dynamic_cast<CNode<TreeData> *>(element);
00404
00405 newLeaves += setLeaves(node->getLeftElement(), numLeaf);
00406
00407 newLeaves += setLeaves(node->getRightElement(), numLeaf + newLeaves);
00408
00409 return newLeaves;
00410 }
00411
00412 }
00413
00414 template<typename TreeData> CTreeElement<TreeData> *CTree<TreeData>::createNode(CTreeElement<TreeData> *parent,CDataSet *inputData, DataSubset *inputDataSubset, CSplittingConditionFactory *splittingFactory, CTreeDataFactory<TreeData> *dataFactory)
00415 {
00416
00417
00418 bool isLeaf = splittingFactory->isLeaf(inputDataSubset);
00419 CTreeElement<TreeData> *newNode = NULL;
00420
00421 if (isLeaf)
00422 {
00423 TreeData data = dataFactory->createTreeData(inputDataSubset, numLeaves);
00424 newNode = new CLeaf<TreeData>(parent, data,inputDataSubset, numLeaves, dataFactory);
00425 numLeaves ++;
00426 }
00427 else
00428 {
00429 CSplittingCondition *split = splittingFactory->createSplittingCondition(inputDataSubset);
00430
00431 DataSubset leftSet;
00432 DataSubset rightSet;
00433
00434 DataSubset::iterator it = inputDataSubset->begin();
00435
00436 for (; it != inputDataSubset->end(); it++)
00437 {
00438 ColumnVector *input = (*inputData)[*it];
00439
00440 if (split->isLeftNode(input))
00441 {
00442 leftSet.insert(*it);
00443 }
00444 else
00445 {
00446 rightSet.insert(*it);
00447 }
00448 }
00449
00450
00451 CNode<TreeData> *newNode1 = new CNode<TreeData>(parent, split, NULL, NULL);
00452
00453 CTreeElement<TreeData> *leftElement = createNode(newNode1, inputData, &leftSet, splittingFactory, dataFactory);
00454
00455 CTreeElement<TreeData> *rightElement = createNode(newNode1, inputData, &rightSet, splittingFactory, dataFactory);
00456
00457 newNode1->setLeftElement(leftElement);
00458 newNode1->setRightElement(rightElement);
00459
00460 newNode = newNode1;
00461
00462 }
00463 return newNode;
00464 }
00465
00466
00467 #endif
00468