00001 #ifndef C_NEAREST_NEIGHBOR__H
00002 #define C_NEAREST_NEIGHBOR__H
00003
00004 #include "ctrees.h"
00005 #include "cinputdata.h"
00006 #include "newmat/newmatio.h"
00007
00008 #include <iostream>
00009 #include <limits>
00010
00011 class CKDRectangle
00012 {
00013 protected:
00014 ColumnVector *minValues;
00015 ColumnVector *maxValues;
00016 public:
00017 CKDRectangle(int numDim);
00018 CKDRectangle(CKDRectangle &rectangle);
00019
00020 ColumnVector *getMinVector();
00021 ColumnVector *getMaxVector();
00022
00023 virtual ~CKDRectangle();
00024
00025 void setMaxValue(int dim, double value);
00026 void setMinValue(int dim, double value);
00027
00028 double getMinValue(int dim);
00029 double getMaxValue(int dim);
00030
00031 double getDistanceToPoint(ColumnVector *point);
00032
00033 bool intersects(CKDRectangle *rectangle);
00034 };
00035
00036 template<typename DataElement, typename TreeData> class CKNearestNeighborsTreeData
00037 {
00038 protected:
00039 double *distList;
00040 DataElement *elementList;
00041
00042 int K;
00043
00044 CTree<TreeData> *tree;
00045
00046 void getNearestNeighborsElements(ColumnVector *point, CTreeElement<TreeData> *element, CKDRectangle *rectangle);
00047
00048
00049 virtual void addAndSortDataElements(DataElement element, double distance);
00050 virtual void addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle) = 0;
00051 public:
00052 CKNearestNeighborsTreeData(CTree<TreeData> *tree, int K);
00053 virtual ~CKNearestNeighborsTreeData();
00054
00055 void getNearestNeighbors(ColumnVector *point, std::list<DataElement> *elementList, int K = -1, ColumnVector *distances = NULL);
00056
00057 void getNearestNeighborDistance(ColumnVector *input, DataElement &nearestNeighbor, double &distance);
00058 };
00059
00060 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighborsElements(ColumnVector *point,CTreeElement<TreeData> *element, CKDRectangle *rectangle)
00061 {
00062 if (element->isLeaf())
00063 {
00064
00065 CLeaf<TreeData> *leaf = dynamic_cast<CLeaf<TreeData> *>(element);
00066 addDataElements(point, leaf, rectangle);
00067 }
00068 else
00069 {
00070 CNode<TreeData> *node = dynamic_cast<CNode<TreeData> *>(element);
00071 C1DSplittingCondition *split = dynamic_cast<C1DSplittingCondition *>(node->getSplittingCondition());
00072
00073 double minValue = rectangle->getMinValue(split->getDimension());
00074 double maxValue = rectangle->getMaxValue(split->getDimension());
00075
00076 double leftDist = 0;
00077 double rightDist = 0;
00078
00079 rectangle->setMaxValue(split->getDimension(), split->getTreshold());
00080 leftDist = rectangle->getDistanceToPoint(point);
00081
00082 rectangle->setMaxValue(split->getDimension(), maxValue);
00083 rectangle->setMinValue(split->getDimension(), split->getTreshold());
00084
00085 rightDist = rectangle->getDistanceToPoint(point);
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098 if (rightDist < leftDist)
00099 {
00100
00101 if (distList[0] == numeric_limits<double>::max() || rightDist < distList[0])
00102 {
00103
00104
00105 getNearestNeighborsElements(point,node->getRightElement(), rectangle);
00106
00107 if (distList[0] == numeric_limits<double>::max() || leftDist < distList[0])
00108 {
00109
00110
00111 rectangle->setMaxValue(split->getDimension(), split->getTreshold());
00112 rectangle->setMinValue(split->getDimension(), minValue);
00113
00114 getNearestNeighborsElements(point,node->getLeftElement(), rectangle);
00115
00116 }
00117 }
00118 }
00119 else
00120 {
00121
00122 if (distList[0] == numeric_limits<double>::max() || leftDist < distList[0])
00123 {
00124
00125 rectangle->setMaxValue(split->getDimension(), split->getTreshold());
00126 rectangle->setMinValue(split->getDimension(), minValue);
00127
00128
00129 getNearestNeighborsElements(point,node->getLeftElement(), rectangle);
00130
00131 if (distList[0] == numeric_limits<double>::max() || rightDist < distList[0])
00132 {
00133
00134
00135 rectangle->setMaxValue(split->getDimension(), maxValue);
00136 rectangle->setMinValue(split->getDimension(), split->getTreshold());
00137
00138 getNearestNeighborsElements(point,node->getRightElement(), rectangle);
00139 }
00140 }
00141 }
00142 rectangle->setMaxValue(split->getDimension(), maxValue);
00143 rectangle->setMinValue(split->getDimension(), minValue);
00144 }
00145 }
00146
00147 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::addAndSortDataElements(DataElement data, double distance)
00148 {
00149 int i = 0;
00150 while (i < K && distList[i] > distance)
00151 {
00152 i ++;
00153 }
00154 i --;
00155
00156
00157 if (i >= 0)
00158 {
00159 for (int j = 0; j < i; j ++)
00160 {
00161 distList[j] = distList[j + 1];
00162 elementList[j] = elementList[j + 1];
00163 }
00164 distList[i] = distance;
00165 elementList[i] = data;
00166 }
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177 }
00178
00179
00180 template<typename DataElement, typename TreeData> CKNearestNeighborsTreeData<DataElement, TreeData>::CKNearestNeighborsTreeData(CTree<TreeData> *l_tree, int l_K)
00181 {
00182 tree = l_tree;
00183 K = l_K;
00184 distList = new double[K];
00185 elementList = new DataElement[K];
00186 }
00187
00188 template<typename DataElement, typename TreeData> CKNearestNeighborsTreeData<DataElement, TreeData>::~CKNearestNeighborsTreeData()
00189 {
00190 delete [] distList;
00191 delete [] elementList;
00192 }
00193
00194 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighbors(ColumnVector *point, std::list<DataElement> *l_elementList, int l_K, ColumnVector *distances)
00195 {
00196 int tempK = K;
00197 if (l_K > 0)
00198 {
00199 K = l_K;
00200 }
00201
00202 for (int i = 0; i < K; i++)
00203 {
00204 distList[i] = numeric_limits<double>::max();
00205 }
00206
00207 ColumnVector *input = tree->getPreprocessedInput(point);
00208
00209 CKDRectangle rectangle(tree->getNumDimensions());
00210
00211
00212
00213
00214
00215 getNearestNeighborsElements(input, tree->getRoot(), &rectangle);
00216
00217 l_elementList->clear();
00218
00219 for (int i = 0; i < K; i ++)
00220 {
00221 if (distances)
00222 {
00223 distances->element( K - 1 - i) = distList[i];
00224 }
00225 if (distList[i] < numeric_limits<double>::max())
00226 {
00227 l_elementList->push_front(elementList[i]);
00228 }
00229 }
00230
00231 K = tempK;
00232 }
00233
00234 template<typename DataElement, typename TreeData> void CKNearestNeighborsTreeData<DataElement, TreeData>::getNearestNeighborDistance(ColumnVector *point, DataElement &nearestNeighbor, double &distance)
00235 {
00236 int tempK = K;
00237 K = 1;
00238
00239
00240 for (int i = 0; i < K; i++)
00241 {
00242 distList[i] = numeric_limits<double>::max();
00243 }
00244
00245 ColumnVector *input = tree->getPreprocessedInput(point);
00246
00247 CKDRectangle rectangle(tree->getNumDimensions());
00248
00249
00250
00251
00252
00253 getNearestNeighborsElements(input, tree->getRoot(), &rectangle);
00254
00255 K = tempK;
00256
00257 nearestNeighbor = elementList[0];
00258 distance = distList[0];
00259
00260 }
00261
00262
00263
00264 class CKNearestNeighbors : public CKNearestNeighborsTreeData<int, DataSubset *>
00265 {
00266 protected:
00267 CDataSet *inputSet;
00268 ColumnVector *buffVector;
00269
00270 virtual void addDataElements(ColumnVector *point, CLeaf<DataSubset *> *leaf, CKDRectangle *rectangle);
00271
00272 public:
00273 CKNearestNeighbors(CTree<DataSubset *> *tree, CDataSet *inputSet, int K);
00274 virtual ~CKNearestNeighbors();
00275
00276 CDataSet *getInputSet() {return inputSet;};
00277 };
00278
00279
00280
00281 template<typename TreeData> class CKNearestLeaves : public CKNearestNeighborsTreeData<int, TreeData>
00282 {
00283 protected:
00284
00285 virtual void addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle);
00286
00287 public:
00288 CKNearestLeaves(CTree<TreeData> *tree, int K);
00289 virtual ~CKNearestLeaves();
00290 };
00291
00292
00293
00294 template<typename TreeData> void CKNearestLeaves<TreeData>::addDataElements(ColumnVector *point, CLeaf<TreeData> *leaf, CKDRectangle *rectangle)
00295 {
00296 addAndSortDataElements(leaf->getLeafNumber(), rectangle->getDistanceToPoint(point));
00297 }
00298
00299
00300 template<typename TreeData> CKNearestLeaves<TreeData>::CKNearestLeaves(CTree<TreeData> *tree, int K) : CKNearestNeighborsTreeData<int, TreeData>(tree, K)
00301 {
00302
00303 }
00304
00305 template<typename TreeData> CKNearestLeaves<TreeData>::~CKNearestLeaves()
00306 {
00307 }
00308
00309 class CRangeSearch
00310 {
00311 protected:
00312 DataSubset *elementList;
00313
00314 CTree<DataSubset *> *tree;
00315 CDataSet *inputSet;
00316
00317 void getSamplesInRangeElements( CKDRectangle *range, CTreeElement<DataSubset *> *element, CKDRectangle *rectangle);
00318
00319
00320 virtual void addAndSortDataElements(int element);
00321 virtual void addDataElements(CKDRectangle *range, CLeaf<DataSubset *> *leaf, CKDRectangle *leafRectangle);
00322 public:
00323 CRangeSearch(CTree<DataSubset *> *tree, CDataSet *l_inputSet);
00324 virtual ~CRangeSearch();
00325
00326 void getSamplesInRange(CKDRectangle *range, DataSubset *elementList);
00327
00328 };
00329
00330
00331
00332 #endif
00333