Reinforcement Learning Toolbox 2.0
last updated:
General
Documentation
Manual
Tutorial
Class Reference
Master Thesis
Examples
Related Papers
Downloads
Links
News
mailto:webmaster
Main Page     Class Hierarchy   Compound List   File List   Compound Members   File Members

cnearestneighbor.h

Go to the documentation of this file.
00001 #ifndef C_NEAREST_NEIGHBOR__H
00002 #define C_NEAREST_NEIGHBOR__H
00003 
00004 #include "ctrees.h"
00005 #include "cinputdata.h"
00006 #include "newmat/newmatio.h"
00007 
00008 #include <iostream>
00009 #include <limits>
00010 
00011 class CKDRectangle 
00012 {
00013         protected:
00014                 ColumnVector *minValues;
00015                 ColumnVector *maxValues;
00016         public:
00017                 CKDRectangle(int numDim);
00018                 CKDRectangle(CKDRectangle &rectangle);
00019 
00020                 ColumnVector *getMinVector();
00021                 ColumnVector *getMaxVector();
00022 
00023                 virtual ~CKDRectangle();                
00024 
00025                 void setMaxValue(int dim, double value);
00026                 void setMinValue(int dim, double value);
00027 
00028                 double getMinValue(int dim);
00029                 double getMaxValue(int dim);
00030 
00031                 double getDistanceToPoint(ColumnVector *point);
00032 
00033                 bool intersects(CKDRectangle *rectangle);
00034 };
00035 
00036 template<typename DataElement, typename TreeData> class CKNearestNeighborsTreeData
00037 {
00038         protected:
00039                 double *distList;
00040                 DataElement *elementList;
00041 
00042                 int K;
00043 
00044                 CTree<TreeData> *tree;
00045 
00046                 void getNearestNeighborsElements(ColumnVector *point, CTreeElement<TreeData> *element, CKDRectangle *rectangle);
00047 
00048                 
00049                 virtual void addAndSortDataElements(DataElement element, double distance);
00050                 virtual void addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle) = 0;
00051         public:
00052                 CKNearestNeighborsTreeData(CTree<TreeData> *tree, int K);
00053                 virtual ~CKNearestNeighborsTreeData();
00054 
00055                 void getNearestNeighbors(ColumnVector *point,  std::list<DataElement> *elementList, int K = -1, ColumnVector *distances = NULL);
00056                 
00057                 void getNearestNeighborDistance(ColumnVector *input, DataElement &nearestNeighbor, double &distance);
00058 };
00059 
00060 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighborsElements(ColumnVector *point,CTreeElement<TreeData> *element, CKDRectangle *rectangle)
00061 {
00062         if (element->isLeaf())
00063         {
00064                 //printf("Adding Leaf with %f Distance\n", rectangle->getDistanceToPoint( point));
00065                 CLeaf<TreeData> *leaf = dynamic_cast<CLeaf<TreeData> *>(element);
00066                 addDataElements(point, leaf, rectangle);
00067         }
00068         else
00069         {
00070                 CNode<TreeData> *node = dynamic_cast<CNode<TreeData> *>(element);       
00071                 C1DSplittingCondition *split = dynamic_cast<C1DSplittingCondition *>(node->getSplittingCondition());
00072 
00073                 double minValue = rectangle->getMinValue(split->getDimension());
00074                 double maxValue = rectangle->getMaxValue(split->getDimension());
00075 
00076                 double leftDist = 0;
00077                 double rightDist = 0;
00078                 // set value for left branch
00079                 rectangle->setMaxValue(split->getDimension(), split->getTreshold());
00080                 leftDist = rectangle->getDistanceToPoint(point);
00081                 
00082                 rectangle->setMaxValue(split->getDimension(), maxValue);
00083                 rectangle->setMinValue(split->getDimension(), split->getTreshold());
00084 
00085                 rightDist = rectangle->getDistanceToPoint(point);       
00086         
00087                 //printf("LeftDist: %f, RightDist: %f...", leftDist, rightDist);        
00088                 /*
00089                for (int i = 0; i < K; i ++)
00090                {
00091                        if (distList[i] < numeric_limits<double>::max())
00092                        {
00093                                printf("%f ", distList[i]);
00094                        }
00095                }
00096                printf("\n");*/
00097 
00098                 if (rightDist < leftDist)
00099                 {
00100                         // point is nearer to the right branch
00101                         if (distList[0] == numeric_limits<double>::max() || rightDist < distList[0])
00102                         {
00103                                 // search NN in right branch
00104                                 //printf("Going Right...\n");
00105                                 getNearestNeighborsElements(point,node->getRightElement(), rectangle);
00106 
00107                                 if (distList[0] == numeric_limits<double>::max() || leftDist < distList[0])
00108                                 {
00109                                 //      printf("Going Left (2)...\n");
00110                                         // also search NN in left branch
00111                                         rectangle->setMaxValue(split->getDimension(),  split->getTreshold());
00112                                         rectangle->setMinValue(split->getDimension(), minValue);
00113                         
00114                                         getNearestNeighborsElements(point,node->getLeftElement(), rectangle);
00115                 
00116                                 }
00117                         }
00118                 }
00119                 else
00120                 {
00121                         // point is nearer to the right branch
00122                         if (distList[0] == numeric_limits<double>::max() || leftDist < distList[0])
00123                         {
00124                                 //printf("Going Left...\n");
00125                                 rectangle->setMaxValue(split->getDimension(), split->getTreshold());
00126                                 rectangle->setMinValue(split->getDimension(), minValue);
00127                                 
00128                                 // search NN in left branch
00129                                 getNearestNeighborsElements(point,node->getLeftElement(), rectangle);
00130         
00131                                 if (distList[0] == numeric_limits<double>::max() || rightDist < distList[0])
00132                                 {
00133                                         //printf("Going Right (2)...\n");
00134                                         // also search NN in right branch
00135                                         rectangle->setMaxValue(split->getDimension(), maxValue);
00136                                         rectangle->setMinValue(split->getDimension(), split->getTreshold());
00137                                 
00138                                         getNearestNeighborsElements(point,node->getRightElement(), rectangle);
00139                                 }
00140                         }
00141                 }
00142                 rectangle->setMaxValue(split->getDimension(), maxValue);
00143                 rectangle->setMinValue(split->getDimension(), minValue); 
00144         }
00145 }
00146         
00147 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::addAndSortDataElements(DataElement data, double distance)
00148 {
00149         int i = 0;
00150         while (i < K && distList[i] > distance)
00151         {
00152                 i ++;
00153         } 
00154         i --;   
00155 
00156         
00157         if (i >= 0)
00158         {
00159                 for (int j = 0; j < i; j ++)
00160                 {
00161                         distList[j] = distList[j + 1];
00162                         elementList[j] = elementList[j + 1];
00163                 }
00164                 distList[i] = distance;
00165                 elementList[i] = data;
00166         }       
00167 
00168         /*printf("Adding Element with distance %f\n", distance);
00169        for (int i = 0; i < K; i ++)
00170        {
00171                if (distList[i] < numeric_limits<double>::max())
00172                {
00173                        printf("%f ", distList[i]);
00174                }
00175        }
00176        printf("\n");*/
00177 }
00178 
00179 
00180 template<typename DataElement, typename TreeData> CKNearestNeighborsTreeData<DataElement, TreeData>::CKNearestNeighborsTreeData(CTree<TreeData> *l_tree, int l_K)
00181 {
00182         tree = l_tree;
00183         K = l_K;
00184         distList = new double[K];
00185         elementList = new DataElement[K];
00186 }
00187 
00188 template<typename DataElement, typename TreeData> CKNearestNeighborsTreeData<DataElement, TreeData>::~CKNearestNeighborsTreeData()
00189 {
00190         delete [] distList;
00191         delete [] elementList;
00192 }
00193 
00194 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighbors(ColumnVector *point, std::list<DataElement> *l_elementList, int l_K, ColumnVector *distances)
00195 {
00196         int tempK = K;
00197         if (l_K > 0)
00198         {
00199                 K = l_K;
00200         }
00201 
00202         for (int i = 0; i < K; i++)
00203         {
00204                 distList[i] = numeric_limits<double>::max();
00205         }
00206 
00207         ColumnVector *input = tree->getPreprocessedInput(point);
00208         
00209         CKDRectangle rectangle(tree->getNumDimensions());
00210         
00211         //printf("Starting NN search: %f for point: ", rectangle.getDistanceToPoint(point));
00212         //cout << point->t() << endl;   
00213         //cout << input->t() << endl;           
00214 
00215         getNearestNeighborsElements(input, tree->getRoot(), &rectangle);
00216         
00217         l_elementList->clear();
00218         
00219         for (int i = 0; i < K; i ++)
00220         {
00221                 if (distances)
00222                 {
00223                         distances->element( K - 1 - i) = distList[i];
00224                 }
00225                 if (distList[i] < numeric_limits<double>::max())
00226                 {
00227                         l_elementList->push_front(elementList[i]);
00228                 }
00229         }
00230 
00231         K = tempK;      
00232 }
00233 
00234 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighborDistance(ColumnVector *point, DataElement &nearestNeighbor, double &distance)
00235 {
00236         int tempK = K;
00237         K = 1;
00238         
00239 
00240         for (int i = 0; i < K; i++)
00241         {
00242                 distList[i] = numeric_limits<double>::max();
00243         }
00244 
00245         ColumnVector *input = tree->getPreprocessedInput(point);
00246         
00247         CKDRectangle rectangle(tree->getNumDimensions());
00248         
00249         //printf("Starting NN search: %f for point: ", rectangle.getDistanceToPoint(point));
00250         //cout << point->t() << endl;   
00251         //cout << input->t() << endl;           
00252 
00253         getNearestNeighborsElements(input, tree->getRoot(), &rectangle);
00254         
00255         K = tempK;      
00256 
00257         nearestNeighbor = elementList[0];
00258         distance = distList[0];
00259 
00260 }
00261 
00262 
00263 
00264 class CKNearestNeighbors : public  CKNearestNeighborsTreeData<int, DataSubset *>
00265 {
00266         protected:
00267                 CDataSet *inputSet;     
00268                 ColumnVector *buffVector;
00269 
00270                 virtual void addDataElements(ColumnVector *point, CLeaf<DataSubset *> *leaf, CKDRectangle *rectangle);
00271                 
00272         public:
00273                 CKNearestNeighbors(CTree<DataSubset *> *tree, CDataSet *inputSet, int K);
00274                 virtual ~CKNearestNeighbors();
00275 
00276                 CDataSet *getInputSet() {return inputSet;};
00277 };
00278 
00279 
00280 
00281 template<typename TreeData> class CKNearestLeaves : public  CKNearestNeighborsTreeData<int, TreeData>
00282 {
00283         protected:
00284                 
00285                 virtual void addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle);
00286                 
00287         public:
00288                 CKNearestLeaves(CTree<TreeData> *tree, int K);
00289                 virtual ~CKNearestLeaves();
00290 };
00291 
00292 
00293 
00294 template<typename TreeData> void CKNearestLeaves<TreeData>::addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle) 
00295 {
00296         addAndSortDataElements(leaf->getLeafNumber(), rectangle->getDistanceToPoint(point));
00297 }
00298                 
00299 
00300 template<typename TreeData> CKNearestLeaves<TreeData>::CKNearestLeaves(CTree<TreeData> *tree, int K) : CKNearestNeighborsTreeData<int, TreeData>(tree, K)
00301 {
00302 
00303 }
00304 
00305 template<typename TreeData> CKNearestLeaves<TreeData>::~CKNearestLeaves()
00306 {
00307 }
00308 
00309 class CRangeSearch
00310 {
00311         protected:
00312                 DataSubset *elementList;
00313 
00314                 CTree<DataSubset *> *tree;
00315                 CDataSet *inputSet;
00316 
00317                 void getSamplesInRangeElements( CKDRectangle *range, CTreeElement<DataSubset *> *element, CKDRectangle *rectangle);
00318 
00319                 
00320                 virtual void addAndSortDataElements(int element);
00321                 virtual void addDataElements(CKDRectangle *range, CLeaf<DataSubset *> *leaf, CKDRectangle *leafRectangle);
00322         public:
00323                 CRangeSearch(CTree<DataSubset *> *tree, CDataSet *l_inputSet);
00324                 virtual ~CRangeSearch();
00325 
00326                 void getSamplesInRange(CKDRectangle *range, DataSubset *elementList);
00327                 
00328 };
00329 
00330 
00331 
00332 #endif
00333