diff --git a/src/mykdtree.hpp b/src/mykdtree.hpp index 56a9898..e87b376 100644 --- a/src/mykdtree.hpp +++ b/src/mykdtree.hpp @@ -8,19 +8,19 @@ namespace CosmoTool { template - struct NGBDef + struct KDDef { typedef float CoordType; - typedef float NGBCoordinates[N]; + typedef float KDCoordinates[N]; static const int NumCubes = 1 << (N); }; template - struct NGBCell + struct KDCell { bool active; ValType val; - typename NGBDef::NGBCoordinates coord; + typename KDDef::KDCoordinates coord; }; class NotEnoughCells: public Exception @@ -31,11 +31,11 @@ namespace CosmoTool { }; template - struct NGBTreeNode + struct KDTreeNode { - NGBCell *value; - NGBTreeNode *children[2]; - typename NGBDef::NGBCoordinates minBound, maxBound; + KDCell *value; + KDTreeNode *children[2]; + typename KDDef::KDCoordinates minBound, maxBound; }; template @@ -43,9 +43,10 @@ namespace CosmoTool { { public: - typename NGBDef::NGBCoordinates x; - typename NGBDef::CoordType r, r2; - NGBCell **cells; + typename KDDef::KDCoordinates x; + typename KDDef::CoordType r, r2; + KDCell **cells; + typename KDDef::CoordType *distances; uint32_t currentRank; uint32_t numCells; }; @@ -54,12 +55,12 @@ namespace CosmoTool { class RecursionMultipleInfo { public: - const typename NGBDef::NGBCoordinates& x; - BoundedQueue< NGBCell *, float> queue; + const typename KDDef::KDCoordinates& x; + BoundedQueue< KDCell *, typename KDDef::CoordType> queue; int traversed; - RecursionMultipleInfo(const typename NGBDef::NGBCoordinates& rx, - NGBCell **cells, + RecursionMultipleInfo(const typename KDDef::KDCoordinates& rx, + KDCell **cells, uint32_t numCells) : x(rx), queue(cells, numCells, INFINITY),traversed(0) { @@ -67,67 +68,76 @@ namespace CosmoTool { }; template - class NGBTree + class KDTree { public: - static const int NumCubes = NGBDef::NumCubes; + static const int NumCubes = KDDef::NumCubes; public: - typedef typename NGBDef::CoordType CoordType; - typedef typename NGBDef::NGBCoordinates coords; + typedef typename KDDef::CoordType CoordType; + typedef typename KDDef::KDCoordinates coords; - NGBTree(NGBCell *cells, uint32_t Ncells); - ~NGBTree(); + KDTree(KDCell *cells, uint32_t Ncells); + ~KDTree(); uint32_t getIntersection(const coords& x, CoordType r, - NGBCell **cells, + KDCell **cells, + uint32_t numCells) + throw (NotEnoughCells); + uint32_t getIntersection(const coords& x, CoordType r, + KDCell **cells, + CoordType *distances, uint32_t numCells) throw (NotEnoughCells); - NGBCell *getNearestNeighbour(const coords& x); - void getNearestNeighbours(const coords& x, uint32_t N, - NGBCell **cells); + KDCell *getNearestNeighbour(const coords& x); - NGBTreeNode *getRoot() { return root; } + void getNearestNeighbours(const coords& x, uint32_t N, + KDCell **cells); + void getNearestNeighbours(const coords& x, uint32_t N, + KDCell **cells, + CoordType *distances); + + KDTreeNode *getRoot() { return root; } void optimize(); - NGBTreeNode *getAllNodes() { return nodes; } + KDTreeNode *getAllNodes() { return nodes; } uint32_t getNumNodes() const { return lastNode; } uint32_t countActives() const; protected: - NGBTreeNode *nodes; + KDTreeNode *nodes; uint32_t numNodes; uint32_t lastNode; - NGBTreeNode *root; - NGBCell **sortingHelper; + KDTreeNode *root; + KDCell **sortingHelper; - NGBTreeNode *buildTree(NGBCell **cell0, + KDTreeNode *buildTree(KDCell **cell0, uint32_t N, uint32_t depth, coords minBound, coords maxBound); void recursiveIntersectionCells(RecursionInfoCells& info, - NGBTreeNode *node, + KDTreeNode *node, int level) throw (NotEnoughCells); - CoordType computeDistance(NGBCell *cell, const coords& x); - void recursiveNearest(NGBTreeNode *node, + CoordType computeDistance(KDCell *cell, const coords& x); + void recursiveNearest(KDTreeNode *node, int level, const coords& x, CoordType& R2, - NGBCell*& cell); - void recursiveMultipleNearest(RecursionMultipleInfo& info, NGBTreeNode *node, + KDCell*& cell); + void recursiveMultipleNearest(RecursionMultipleInfo& info, KDTreeNode *node, int level); }; template - uint32_t gatherActiveCells(NGBCell **cells, uint32_t numCells); + uint32_t gatherActiveCells(KDCell **cells, uint32_t numCells); }; diff --git a/src/mykdtree.tcc b/src/mykdtree.tcc index 924e933..66ec29e 100644 --- a/src/mykdtree.tcc +++ b/src/mykdtree.tcc @@ -14,7 +14,7 @@ namespace CosmoTool { rank = k; } - bool operator()(const NGBCell *a, const NGBCell *b) const + bool operator()(const KDCell *a, const KDCell *b) const { return (a->coord[rank] < b->coord[rank]); } @@ -23,18 +23,18 @@ namespace CosmoTool { }; template - NGBTree::~NGBTree() + KDTree::~KDTree() { } template - NGBTree::NGBTree(NGBCell *cells, uint32_t Ncells) + KDTree::KDTree(KDCell *cells, uint32_t Ncells) { numNodes = Ncells; - nodes = new NGBTreeNode[numNodes]; + nodes = new KDTreeNode[numNodes]; - sortingHelper = new NGBCell *[Ncells]; + sortingHelper = new KDCell *[Ncells]; for (uint32_t i = 0; i < Ncells; i++) sortingHelper[i] = &cells[i]; @@ -42,7 +42,7 @@ namespace CosmoTool { } template - void NGBTree::optimize() + void KDTree::optimize() { coords absoluteMin, absoluteMax; @@ -63,8 +63,8 @@ namespace CosmoTool { } template - uint32_t NGBTree::getIntersection(const coords& x, CoordType r, - NGBCell **cells, + uint32_t KDTree::getIntersection(const coords& x, CoordType r, + KDCell **cells, uint32_t numCells) throw (NotEnoughCells) { @@ -76,14 +76,36 @@ namespace CosmoTool { info.cells = cells; info.currentRank = 0; info.numCells = numCells; + info.distances = 0; recursiveIntersectionCells(info, root, 0); return info.currentRank; } template - void NGBTree::recursiveIntersectionCells(RecursionInfoCells& info, - NGBTreeNode *node, + uint32_t KDTree::getIntersection(const coords& x, CoordType r, + KDCell **cells, + CoordType *distances, + uint32_t numCells) + throw (NotEnoughCells) + { + RecursionInfoCells info; + + memcpy(info.x, x, sizeof(x)); + info.r = r; + info.r2 = r*r; + info.cells = cells; + info.currentRank = 0; + info.numCells = numCells; + info.distances = distances; + + recursiveIntersectionCells(info, root, 0); + return info.currentRank; + } + + template + void KDTree::recursiveIntersectionCells(RecursionInfoCells& info, + KDTreeNode *node, int level) throw (NotEnoughCells) { @@ -101,7 +123,10 @@ namespace CosmoTool { { if (info.currentRank == info.numCells) throw NotEnoughCells(); - info.cells[info.currentRank++] = node->value; + info.cells[info.currentRank] = node->value; + if (info.distances) + info.distances[info.currentRank] = d2; + info.currentRank++; } } @@ -123,7 +148,7 @@ namespace CosmoTool { } template - uint32_t gatherActiveCells(NGBCell **cells, + uint32_t gatherActiveCells(KDCell **cells, uint32_t Ncells) { uint32_t swapId = Ncells-1; @@ -148,7 +173,7 @@ namespace CosmoTool { } template - NGBTreeNode *NGBTree::buildTree(NGBCell **cell0, + KDTreeNode *KDTree::buildTree(KDCell **cell0, uint32_t Ncells, uint32_t depth, coords minBound, @@ -158,7 +183,7 @@ namespace CosmoTool { return 0; int axis = depth % N; - NGBTreeNode *node = &nodes[lastNode++]; + KDTreeNode *node = &nodes[lastNode++]; uint32_t mid = Ncells/2; coords tmpBound; @@ -187,7 +212,7 @@ namespace CosmoTool { } template - uint32_t NGBTree::countActives() const + uint32_t KDTree::countActives() const { uint32_t numActive = 0; for (uint32_t i = 0; i < lastNode; i++) @@ -199,8 +224,8 @@ namespace CosmoTool { } template - typename NGBDef::CoordType - NGBTree::computeDistance(NGBCell *cell, const coords& x) + typename KDDef::CoordType + KDTree::computeDistance(KDCell *cell, const coords& x) { CoordType d2 = 0; @@ -214,16 +239,16 @@ namespace CosmoTool { template void - NGBTree::recursiveNearest( - NGBTreeNode *node, + KDTree::recursiveNearest( + KDTreeNode *node, int level, const coords& x, CoordType& R2, - NGBCell *& best) + KDCell *& best) { CoordType d2 = 0; int axis = level % N; - NGBTreeNode *other, *go; + KDTreeNode *other, *go; if (x[axis] < node->value->coord[axis]) { @@ -287,11 +312,11 @@ namespace CosmoTool { } template - NGBCell * - NGBTree::getNearestNeighbour(const coords& x) + KDCell * + KDTree::getNearestNeighbour(const coords& x) { CoordType R2 = INFINITY; - NGBCell *best = 0; + KDCell *best = 0; recursiveNearest(root, 0, x, R2, best); @@ -300,12 +325,12 @@ namespace CosmoTool { template void - NGBTree::recursiveMultipleNearest(RecursionMultipleInfo& info, NGBTreeNode *node, + KDTree::recursiveMultipleNearest(RecursionMultipleInfo& info, KDTreeNode *node, int level) { CoordType d2 = 0; int axis = level % N; - NGBTreeNode *other, *go; + KDTreeNode *other, *go; if (info.x[axis] < node->value->coord[axis]) { @@ -356,8 +381,8 @@ namespace CosmoTool { } template - void NGBTree::getNearestNeighbours(const coords& x, uint32_t N2, - NGBCell **cells) + void KDTree::getNearestNeighbours(const coords& x, uint32_t N2, + KDCell **cells) { RecursionMultipleInfo info(x, cells, N2); @@ -369,4 +394,20 @@ namespace CosmoTool { std::cout << "Traversed = " << info.traversed << std::endl; } + template + void KDTree::getNearestNeighbours(const coords& x, uint32_t N2, + KDCell **cells, + CoordType *distances) + { + RecursionMultipleInfo info(x, cells, N2); + + for (int i = 0; i < N2; i++) + cells[i] = 0; + + recursiveMultipleNearest(info, root, 0); + memcpy(distances, info.getPriorities(), sizeof(CoordType)*N2); + + std::cout << "Traversed = " << info.traversed << std::endl; + } + };