From a6341d33788694ac2f9166a66301531b628c0af8 Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Thu, 8 Jan 2009 09:18:14 -0600 Subject: [PATCH] Generic KD Tree --- src/mykdtree.hpp | 122 +++++++++++++++++++ src/mykdtree.tcc | 307 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 429 insertions(+) create mode 100644 src/mykdtree.hpp create mode 100644 src/mykdtree.tcc diff --git a/src/mykdtree.hpp b/src/mykdtree.hpp new file mode 100644 index 0000000..5734a57 --- /dev/null +++ b/src/mykdtree.hpp @@ -0,0 +1,122 @@ +#ifndef __HV_KDTREE_HPP +#define __HV_KDTREE_HPP + +#include +#include "config.hpp" + +namespace CosmoTool { + + template + struct NGBDef + { + typedef float CoordType; + typedef float NGBCoordinates[N]; + static const int NumCubes = 1 << (N); + }; + + template + struct NGBCell + { + bool active; + ValType val; + typename NGBDef::NGBCoordinates coord; + }; + + class NotEnoughCells: public Exception + { + public: + NotEnoughCells() : Exception() {} + ~NotEnoughCells() throw () {} + }; + + template + struct NGBTreeNode + { + NGBCell *value; + NGBTreeNode *children[2]; + typename NGBDef::NGBCoordinates minBound, maxBound; + }; + + template + class RecursionInfoCells + { + public: + + typename NGBDef::NGBCoordinates x; + typename NGBDef::CoordType r, r2; + NGBCell **cells; + uint32_t currentRank; + uint32_t numCells; + }; + + template + class NGBTree + { + public: + static const int NumCubes = NGBDef::NumCubes; + + public: + typedef typename NGBDef::CoordType CoordType; + typedef typename NGBDef::NGBCoordinates coords; + typedef typename CosmoQueue*> NGBQueue; + + NGBTree(NGBCell *cells, uint32_t Ncells); + ~NGBTree(); + + uint32_t getIntersection(const coords& x, CoordType r, + NGBCell **cells, + uint32_t numCells) + throw (NotEnoughCells); + + NGBCell *getNearestNeighbour(const coords& x); + void getNearestNeighbours(const coords& x, uint32_t N, + NGBCell **cells); + + NGBTreeNode *getRoot() { return root; } + + void optimize(); + + NGBTreeNode *getAllNodes() { return nodes; } + uint32_t getNumNodes() const { return lastNode; } + + uint32_t countActives() const; + + protected: + NGBTreeNode *nodes; + uint32_t numNodes; + uint32_t lastNode; + + NGBTreeNode *root; + NGBCell **sortingHelper; + + NGBTreeNode *buildTree(NGBCell **cell0, + uint32_t N, + uint32_t depth, + coords minBound, + coords maxBound); + + void recursiveIntersectionCells(RecursionInfoCells& info, + NGBTreeNode *node, + int level) + throw (NotEnoughCells); + + CoordType computeDistance(NGBCell *cell, const coords& x); + void recursiveNearest(NGBTreeNode *node, + int level, + const coords& x, + CoordType& R2, + NGBCell*& cell); + void recursiveMultiNearest(NGBTreeNode *node, + int level, + const coords& x, + NGBQueue& q); + }; + + template + uint32_t gatherActiveCells(NGBCell **cells, uint32_t numCells); + +}; + +#include "mykdtree.tcc" + +#endif diff --git a/src/mykdtree.tcc b/src/mykdtree.tcc new file mode 100644 index 0000000..feb7a42 --- /dev/null +++ b/src/mykdtree.tcc @@ -0,0 +1,307 @@ +#include +#include +#include +#include + +namespace CosmoTool { + + template + class CellCompare + { + public: + CellCompare(int k) + { + rank = k; + } + + bool operator()(const NGBCell *a, const NGBCell *b) const + { + return (a->coord[rank] < b->coord[rank]); + } + protected: + int rank; + }; + + template + NGBTree::~NGBTree() + { + } + + template + NGBTree::NGBTree(NGBCell *cells, uint32_t Ncells) + { + + numNodes = Ncells; + nodes = new NGBTreeNode[numNodes]; + + sortingHelper = new NGBCell *[Ncells]; + for (uint32_t i = 0; i < Ncells; i++) + sortingHelper[i] = &cells[i]; + + optimize(); + } + + template + void NGBTree::optimize() + { + coords absoluteMin, absoluteMax; + + std::cout << "Optimizing the tree..." << std::endl; + uint32_t activeCells = gatherActiveCells(sortingHelper, numNodes); + std::cout << " number of active cells = " << activeCells << std::endl; + + lastNode = 0; + for (int i = 0; i < N; i++) + { + absoluteMin[i] = -std::numeric_limits::max(); + absoluteMax[i] = std::numeric_limits::max(); + } + + std::cout << " rebuilding the tree..." << std::endl; + root = buildTree(sortingHelper, activeCells, 0, absoluteMin, absoluteMax); + std::cout << " done." << std::endl; + } + + template + uint32_t NGBTree::getIntersection(const coords& x, CoordType r, + NGBCell **cells, + uint32_t numCells) + throw (NotEnoughCells) + { + RecursionInfoCells info; + + memcpy(info.x, x, sizeof(x)); + info.r = r; + info.r2 = r*r; + info.cells = cells; + info.currentRank = 0; + info.numCells = numCells; + + recursiveIntersectionCells(info, root, 0); + return info.currentRank; + } + + template + void NGBTree::recursiveIntersectionCells(RecursionInfoCells& info, + NGBTreeNode *node, + int level) + throw (NotEnoughCells) + { + int axis = level % N; + CoordType d2 = 0; + + if (node->value->active) + { + for (int j = 0; j < 3; j++) + { + CoordType delta = info.x[j]-node->value->coord[j]; + d2 += delta*delta; + } + if (d2 < info.r2) + { + if (info.currentRank == info.numCells) + throw NotEnoughCells(); + info.cells[info.currentRank++] = node->value; + } + } + + // The hypersphere intersects the left child node + if (((info.x[axis]+info.r) > node->minBound[axis]) && + ((info.x[axis]-info.r) < node->value->coord[axis])) + { + if (node->children[0] != 0) + recursiveIntersectionCells(info, node->children[0], + level+1); + } + if (((info.x[axis]+info.r) > node->value->coord[axis]) && + ((info.x[axis]-info.r) < node->maxBound[axis])) + { + if (node->children[1] != 0) + recursiveIntersectionCells(info, node->children[1], + level+1); + } + } + + template + uint32_t gatherActiveCells(NGBCell **cells, + uint32_t Ncells) + { + uint32_t swapId = Ncells-1; + uint32_t i = 0; + + while (!cells[swapId]->active && swapId > 0) + swapId--; + + while (i < swapId) + { + if (!cells[i]->active) + { + std::swap(cells[i], cells[swapId]); + while (!cells[swapId]->active && swapId > i) + { + swapId--; + } + } + i++; + } + return swapId+1; + } + + template + NGBTreeNode *NGBTree::buildTree(NGBCell **cell0, + uint32_t Ncells, + uint32_t depth, + coords minBound, + coords maxBound) + { + if (Ncells == 0) + return 0; + + int axis = depth % N; + NGBTreeNode *node = &nodes[lastNode++]; + uint32_t mid = Ncells/2; + coords tmpBound; + + // Isolate the environment + { + CellCompare compare(axis); + std::sort(cell0, cell0+Ncells, compare); + } + + node->value = *(cell0+mid); + memcpy(&node->minBound[0], &minBound[0], sizeof(coords)); + memcpy(&node->maxBound[0], &maxBound[0], sizeof(coords)); + + memcpy(tmpBound, maxBound, sizeof(coords)); + tmpBound[axis] = node->value->coord[axis]; + + depth++; + node->children[0] = buildTree(cell0, mid, depth, minBound, tmpBound); + + memcpy(tmpBound, minBound, sizeof(coords)); + tmpBound[axis] = node->value->coord[axis]; + node->children[1] = buildTree(cell0+mid+1, Ncells-mid-1, depth, + tmpBound, maxBound); + + return node; + } + + template + uint32_t NGBTree::countActives() const + { + uint32_t numActive = 0; + for (uint32_t i = 0; i < lastNode; i++) + { + if (nodes[i].value->active) + numActive++; + } + return numActive; + } + + template + typename NGBDef::CoordType + NGBTree::computeDistance(NGBCell *cell, const coords& x) + { + CoordType d2 = 0; + + for (int i = 0; i < N; i++) + { + CoordType delta = cell->coord[i] - x[i]; + d2 += delta*delta; + } + return d2; + } + + template + void + NGBTree::recursiveNearest( + NGBTreeNode *node, + int level, + const coords& x, + CoordType& R2, + NGBCell *& best) + { + CoordType d2 = 0; + int axis = level % N; + NGBTreeNode *other, *go; + + if (x[axis] < node->value->coord[axis]) + { + // The best is potentially in 0. + go = node->children[0]; + other = node->children[1]; + } + else + { + // If not it is in 1. + go = node->children[1]; + other = node->children[0]; + if (go == 0) + { + go = other; + other = 0; + } + } + + if (go != 0) + { + recursiveNearest(go, level+1, + x, R2,best); + } + else + { + CoordType thisR2 = computeDistance(node->value, x); + if (thisR2 < R2) + { + R2 = thisR2; + best = node->value; + } + return; + } + + // Check if current node is not the nearest + CoordType thisR2 = + computeDistance(node->value, x); + + if (thisR2 < R2) + { + R2 = thisR2; + best = node->value; + } + + // Now we found the best. We check whether the hypersphere + // intersect the hyperplane of the other branch + + CoordType delta1; + + delta1 = x[axis]-node->value->coord[axis]; + if (delta1*delta1 < R2) + { + // The hypersphere intersects the hyperplane. Try the + // other branch + if (other != 0) + { + recursiveNearest(other, level+1, x, R2, best); + } + } + } + + template + NGBCell * + NGBTree::getNearestNeighbour(const coords& x) + { + CoordType R2 = INFINITY; + NGBCell *best = 0; + + recursiveNearest(root, 0, x, R2, best); + + return best; + } + + template + void NGBTree::getNearestNeighbours(const coords& x, uint32_t N2, + NGBCell **cells) + { + } + +};