Generic KD Tree
This commit is contained in:
parent
de35eeabd8
commit
a6341d3378
122
src/mykdtree.hpp
Normal file
122
src/mykdtree.hpp
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
#ifndef __HV_KDTREE_HPP
|
||||||
|
#define __HV_KDTREE_HPP
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include "config.hpp"
|
||||||
|
|
||||||
|
namespace CosmoTool {
|
||||||
|
|
||||||
|
template<int N>
|
||||||
|
struct NGBDef
|
||||||
|
{
|
||||||
|
typedef float CoordType;
|
||||||
|
typedef float NGBCoordinates[N];
|
||||||
|
static const int NumCubes = 1 << (N);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
struct NGBCell
|
||||||
|
{
|
||||||
|
bool active;
|
||||||
|
ValType val;
|
||||||
|
typename NGBDef<N>::NGBCoordinates coord;
|
||||||
|
};
|
||||||
|
|
||||||
|
class NotEnoughCells: public Exception
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NotEnoughCells() : Exception() {}
|
||||||
|
~NotEnoughCells() throw () {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
struct NGBTreeNode
|
||||||
|
{
|
||||||
|
NGBCell<N,ValType> *value;
|
||||||
|
NGBTreeNode<N,ValType> *children[2];
|
||||||
|
typename NGBDef<N>::NGBCoordinates minBound, maxBound;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
class RecursionInfoCells
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
typename NGBDef<N>::NGBCoordinates x;
|
||||||
|
typename NGBDef<N>::CoordType r, r2;
|
||||||
|
NGBCell<N, ValType> **cells;
|
||||||
|
uint32_t currentRank;
|
||||||
|
uint32_t numCells;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
class NGBTree
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static const int NumCubes = NGBDef<N>::NumCubes;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef typename NGBDef<N>::CoordType CoordType;
|
||||||
|
typedef typename NGBDef<N>::NGBCoordinates coords;
|
||||||
|
typedef typename CosmoQueue<NGBCell<N,ValType>*> NGBQueue;
|
||||||
|
|
||||||
|
NGBTree(NGBCell<N,ValType> *cells, uint32_t Ncells);
|
||||||
|
~NGBTree();
|
||||||
|
|
||||||
|
uint32_t getIntersection(const coords& x, CoordType r,
|
||||||
|
NGBCell<N, ValType> **cells,
|
||||||
|
uint32_t numCells)
|
||||||
|
throw (NotEnoughCells);
|
||||||
|
|
||||||
|
NGBCell<N, ValType> *getNearestNeighbour(const coords& x);
|
||||||
|
void getNearestNeighbours(const coords& x, uint32_t N,
|
||||||
|
NGBCell<N, ValType> **cells);
|
||||||
|
|
||||||
|
NGBTreeNode<N,ValType> *getRoot() { return root; }
|
||||||
|
|
||||||
|
void optimize();
|
||||||
|
|
||||||
|
NGBTreeNode<N,ValType> *getAllNodes() { return nodes; }
|
||||||
|
uint32_t getNumNodes() const { return lastNode; }
|
||||||
|
|
||||||
|
uint32_t countActives() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
NGBTreeNode<N, ValType> *nodes;
|
||||||
|
uint32_t numNodes;
|
||||||
|
uint32_t lastNode;
|
||||||
|
|
||||||
|
NGBTreeNode<N, ValType> *root;
|
||||||
|
NGBCell<N, ValType> **sortingHelper;
|
||||||
|
|
||||||
|
NGBTreeNode<N, ValType> *buildTree(NGBCell<N,ValType> **cell0,
|
||||||
|
uint32_t N,
|
||||||
|
uint32_t depth,
|
||||||
|
coords minBound,
|
||||||
|
coords maxBound);
|
||||||
|
|
||||||
|
void recursiveIntersectionCells(RecursionInfoCells<N,ValType>& info,
|
||||||
|
NGBTreeNode<N,ValType> *node,
|
||||||
|
int level)
|
||||||
|
throw (NotEnoughCells);
|
||||||
|
|
||||||
|
CoordType computeDistance(NGBCell<N,ValType> *cell, const coords& x);
|
||||||
|
void recursiveNearest(NGBTreeNode<N, ValType> *node,
|
||||||
|
int level,
|
||||||
|
const coords& x,
|
||||||
|
CoordType& R2,
|
||||||
|
NGBCell<N,ValType>*& cell);
|
||||||
|
void recursiveMultiNearest(NGBTreeNode<N, ValType> *node,
|
||||||
|
int level,
|
||||||
|
const coords& x,
|
||||||
|
NGBQueue& q);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, class T>
|
||||||
|
uint32_t gatherActiveCells(NGBCell<N,T> **cells, uint32_t numCells);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#include "mykdtree.tcc"
|
||||||
|
|
||||||
|
#endif
|
307
src/mykdtree.tcc
Normal file
307
src/mykdtree.tcc
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
#include <cstring>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace CosmoTool {
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
class CellCompare
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CellCompare(int k)
|
||||||
|
{
|
||||||
|
rank = k;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator()(const NGBCell<N,ValType> *a, const NGBCell<N,ValType> *b) const
|
||||||
|
{
|
||||||
|
return (a->coord[rank] < b->coord[rank]);
|
||||||
|
}
|
||||||
|
protected:
|
||||||
|
int rank;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
NGBTree<N,ValType>::~NGBTree()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
NGBTree<N,ValType>::NGBTree(NGBCell<N,ValType> *cells, uint32_t Ncells)
|
||||||
|
{
|
||||||
|
|
||||||
|
numNodes = Ncells;
|
||||||
|
nodes = new NGBTreeNode<N,ValType>[numNodes];
|
||||||
|
|
||||||
|
sortingHelper = new NGBCell<N,ValType> *[Ncells];
|
||||||
|
for (uint32_t i = 0; i < Ncells; i++)
|
||||||
|
sortingHelper[i] = &cells[i];
|
||||||
|
|
||||||
|
optimize();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
void NGBTree<N,ValType>::optimize()
|
||||||
|
{
|
||||||
|
coords absoluteMin, absoluteMax;
|
||||||
|
|
||||||
|
std::cout << "Optimizing the tree..." << std::endl;
|
||||||
|
uint32_t activeCells = gatherActiveCells(sortingHelper, numNodes);
|
||||||
|
std::cout << " number of active cells = " << activeCells << std::endl;
|
||||||
|
|
||||||
|
lastNode = 0;
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
absoluteMin[i] = -std::numeric_limits<typeof (absoluteMin[0])>::max();
|
||||||
|
absoluteMax[i] = std::numeric_limits<typeof (absoluteMax[0])>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << " rebuilding the tree..." << std::endl;
|
||||||
|
root = buildTree(sortingHelper, activeCells, 0, absoluteMin, absoluteMax);
|
||||||
|
std::cout << " done." << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
uint32_t NGBTree<N,ValType>::getIntersection(const coords& x, CoordType r,
|
||||||
|
NGBCell<N, ValType> **cells,
|
||||||
|
uint32_t numCells)
|
||||||
|
throw (NotEnoughCells)
|
||||||
|
{
|
||||||
|
RecursionInfoCells<N,ValType> info;
|
||||||
|
|
||||||
|
memcpy(info.x, x, sizeof(x));
|
||||||
|
info.r = r;
|
||||||
|
info.r2 = r*r;
|
||||||
|
info.cells = cells;
|
||||||
|
info.currentRank = 0;
|
||||||
|
info.numCells = numCells;
|
||||||
|
|
||||||
|
recursiveIntersectionCells(info, root, 0);
|
||||||
|
return info.currentRank;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
void NGBTree<N,ValType>::recursiveIntersectionCells(RecursionInfoCells<N,ValType>& info,
|
||||||
|
NGBTreeNode<N,ValType> *node,
|
||||||
|
int level)
|
||||||
|
throw (NotEnoughCells)
|
||||||
|
{
|
||||||
|
int axis = level % N;
|
||||||
|
CoordType d2 = 0;
|
||||||
|
|
||||||
|
if (node->value->active)
|
||||||
|
{
|
||||||
|
for (int j = 0; j < 3; j++)
|
||||||
|
{
|
||||||
|
CoordType delta = info.x[j]-node->value->coord[j];
|
||||||
|
d2 += delta*delta;
|
||||||
|
}
|
||||||
|
if (d2 < info.r2)
|
||||||
|
{
|
||||||
|
if (info.currentRank == info.numCells)
|
||||||
|
throw NotEnoughCells();
|
||||||
|
info.cells[info.currentRank++] = node->value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The hypersphere intersects the left child node
|
||||||
|
if (((info.x[axis]+info.r) > node->minBound[axis]) &&
|
||||||
|
((info.x[axis]-info.r) < node->value->coord[axis]))
|
||||||
|
{
|
||||||
|
if (node->children[0] != 0)
|
||||||
|
recursiveIntersectionCells(info, node->children[0],
|
||||||
|
level+1);
|
||||||
|
}
|
||||||
|
if (((info.x[axis]+info.r) > node->value->coord[axis]) &&
|
||||||
|
((info.x[axis]-info.r) < node->maxBound[axis]))
|
||||||
|
{
|
||||||
|
if (node->children[1] != 0)
|
||||||
|
recursiveIntersectionCells(info, node->children[1],
|
||||||
|
level+1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
uint32_t gatherActiveCells(NGBCell<N,ValType> **cells,
|
||||||
|
uint32_t Ncells)
|
||||||
|
{
|
||||||
|
uint32_t swapId = Ncells-1;
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
while (!cells[swapId]->active && swapId > 0)
|
||||||
|
swapId--;
|
||||||
|
|
||||||
|
while (i < swapId)
|
||||||
|
{
|
||||||
|
if (!cells[i]->active)
|
||||||
|
{
|
||||||
|
std::swap(cells[i], cells[swapId]);
|
||||||
|
while (!cells[swapId]->active && swapId > i)
|
||||||
|
{
|
||||||
|
swapId--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return swapId+1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
NGBTreeNode<N, ValType> *NGBTree<N,ValType>::buildTree(NGBCell<N,ValType> **cell0,
|
||||||
|
uint32_t Ncells,
|
||||||
|
uint32_t depth,
|
||||||
|
coords minBound,
|
||||||
|
coords maxBound)
|
||||||
|
{
|
||||||
|
if (Ncells == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
int axis = depth % N;
|
||||||
|
NGBTreeNode<N,ValType> *node = &nodes[lastNode++];
|
||||||
|
uint32_t mid = Ncells/2;
|
||||||
|
coords tmpBound;
|
||||||
|
|
||||||
|
// Isolate the environment
|
||||||
|
{
|
||||||
|
CellCompare<N,ValType> compare(axis);
|
||||||
|
std::sort(cell0, cell0+Ncells, compare);
|
||||||
|
}
|
||||||
|
|
||||||
|
node->value = *(cell0+mid);
|
||||||
|
memcpy(&node->minBound[0], &minBound[0], sizeof(coords));
|
||||||
|
memcpy(&node->maxBound[0], &maxBound[0], sizeof(coords));
|
||||||
|
|
||||||
|
memcpy(tmpBound, maxBound, sizeof(coords));
|
||||||
|
tmpBound[axis] = node->value->coord[axis];
|
||||||
|
|
||||||
|
depth++;
|
||||||
|
node->children[0] = buildTree(cell0, mid, depth, minBound, tmpBound);
|
||||||
|
|
||||||
|
memcpy(tmpBound, minBound, sizeof(coords));
|
||||||
|
tmpBound[axis] = node->value->coord[axis];
|
||||||
|
node->children[1] = buildTree(cell0+mid+1, Ncells-mid-1, depth,
|
||||||
|
tmpBound, maxBound);
|
||||||
|
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
uint32_t NGBTree<N,ValType>::countActives() const
|
||||||
|
{
|
||||||
|
uint32_t numActive = 0;
|
||||||
|
for (uint32_t i = 0; i < lastNode; i++)
|
||||||
|
{
|
||||||
|
if (nodes[i].value->active)
|
||||||
|
numActive++;
|
||||||
|
}
|
||||||
|
return numActive;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
typename NGBDef<N>::CoordType
|
||||||
|
NGBTree<N,ValType>::computeDistance(NGBCell<N, ValType> *cell, const coords& x)
|
||||||
|
{
|
||||||
|
CoordType d2 = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
CoordType delta = cell->coord[i] - x[i];
|
||||||
|
d2 += delta*delta;
|
||||||
|
}
|
||||||
|
return d2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
void
|
||||||
|
NGBTree<N,ValType>::recursiveNearest(
|
||||||
|
NGBTreeNode<N,ValType> *node,
|
||||||
|
int level,
|
||||||
|
const coords& x,
|
||||||
|
CoordType& R2,
|
||||||
|
NGBCell<N,ValType> *& best)
|
||||||
|
{
|
||||||
|
CoordType d2 = 0;
|
||||||
|
int axis = level % N;
|
||||||
|
NGBTreeNode<N,ValType> *other, *go;
|
||||||
|
|
||||||
|
if (x[axis] < node->value->coord[axis])
|
||||||
|
{
|
||||||
|
// The best is potentially in 0.
|
||||||
|
go = node->children[0];
|
||||||
|
other = node->children[1];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// If not it is in 1.
|
||||||
|
go = node->children[1];
|
||||||
|
other = node->children[0];
|
||||||
|
if (go == 0)
|
||||||
|
{
|
||||||
|
go = other;
|
||||||
|
other = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (go != 0)
|
||||||
|
{
|
||||||
|
recursiveNearest(go, level+1,
|
||||||
|
x, R2,best);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
CoordType thisR2 = computeDistance(node->value, x);
|
||||||
|
if (thisR2 < R2)
|
||||||
|
{
|
||||||
|
R2 = thisR2;
|
||||||
|
best = node->value;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if current node is not the nearest
|
||||||
|
CoordType thisR2 =
|
||||||
|
computeDistance(node->value, x);
|
||||||
|
|
||||||
|
if (thisR2 < R2)
|
||||||
|
{
|
||||||
|
R2 = thisR2;
|
||||||
|
best = node->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we found the best. We check whether the hypersphere
|
||||||
|
// intersect the hyperplane of the other branch
|
||||||
|
|
||||||
|
CoordType delta1;
|
||||||
|
|
||||||
|
delta1 = x[axis]-node->value->coord[axis];
|
||||||
|
if (delta1*delta1 < R2)
|
||||||
|
{
|
||||||
|
// The hypersphere intersects the hyperplane. Try the
|
||||||
|
// other branch
|
||||||
|
if (other != 0)
|
||||||
|
{
|
||||||
|
recursiveNearest(other, level+1, x, R2, best);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
NGBCell<N, ValType> *
|
||||||
|
NGBTree<N,ValType>::getNearestNeighbour(const coords& x)
|
||||||
|
{
|
||||||
|
CoordType R2 = INFINITY;
|
||||||
|
NGBCell<N,ValType> *best = 0;
|
||||||
|
|
||||||
|
recursiveNearest(root, 0, x, R2, best);
|
||||||
|
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int N, typename ValType>
|
||||||
|
void NGBTree<N,ValType>::getNearestNeighbours(const coords& x, uint32_t N2,
|
||||||
|
NGBCell<N, ValType> **cells)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
Loading…
Reference in New Issue
Block a user