// cosmotool/src/mykdtree.tcc
// (373 lines, 8.5 KiB, C++)

#include <cstring>
#include <algorithm>
#include <limits>
#include <iostream>
namespace CosmoTool {
// Comparison functor over pointers to NGBCell: orders two cells by one
// coordinate component, whose index is fixed when the functor is built.
template<int N, typename ValType>
class CellCompare
{
public:
  // k is the index into NGBCell::coord that every comparison will use.
  CellCompare(int k)
    : rank(k)
  {
  }

  // True when a's coordinate along the selected component is strictly
  // smaller than b's.
  bool operator()(const NGBCell<N,ValType> *a, const NGBCell<N,ValType> *b) const
  {
    const bool aComesFirst = a->coord[rank] < b->coord[rank];
    return aComesFirst;
  }

protected:
  // Coordinate index compared by operator().
  int rank;
};
// Releases the node array and the sorting-helper array that the
// constructor allocates with new[]. The cells handed to the constructor
// are only referenced through pointers stored in sortingHelper, so they
// are not freed here.
// NOTE(review): assumes every constructor initializes both pointers and
// that the tree is never copied by value (the class declaration is not
// visible from this file) -- confirm against the header.
template<int N, typename ValType>
NGBTree<N,ValType>::~NGBTree()
{
  delete[] nodes;
  delete[] sortingHelper;
}
// Builds a tree over the Ncells cells starting at `cells`.
// Allocates one tree node per cell plus an array of cell pointers
// (sortingHelper) that initially points at the cells in input order, then
// calls optimize() to construct the tree. The cells themselves are not
// copied: only their addresses are stored.
template<int N, typename ValType>
NGBTree<N,ValType>::NGBTree(NGBCell<N,ValType> *cells, uint32_t Ncells)
{
  numNodes = Ncells;
  nodes = new NGBTreeNode<N,ValType>[numNodes];
  sortingHelper = new NGBCell<N,ValType> *[Ncells];

  uint32_t idx = 0;
  while (idx < Ncells)
    {
      sortingHelper[idx] = cells + idx;
      ++idx;
    }

  optimize();
}
// (Re)builds the tree from the cells referenced by sortingHelper.
// Steps: move the active cells to the front of sortingHelper, reset the
// node allocation cursor, then build the tree over the active cells with
// an initial bounding box spanning the whole representable range on every
// axis. Progress messages are written to std::cout.
// NOTE(review): `typeof` is a GNU extension; it is kept here because the
// element type of `coords` is not visible from this file.
template<int N, typename ValType>
void NGBTree<N,ValType>::optimize()
{
  std::cout << "Optimizing the tree..." << std::endl;

  const uint32_t activeCells = gatherActiveCells(sortingHelper, numNodes);
  std::cout << " number of active cells = " << activeCells << std::endl;

  // Start from an unbounded box: buildTree narrows it at each split.
  coords absoluteMin, absoluteMax;
  for (int axis = 0; axis < N; ++axis)
    {
      absoluteMax[axis] = std::numeric_limits<typeof (absoluteMax[0])>::max();
      absoluteMin[axis] = -std::numeric_limits<typeof (absoluteMin[0])>::max();
    }

  // Nodes are handed out sequentially from `nodes`, starting over at 0.
  lastNode = 0;

  std::cout << " rebuilding the tree..." << std::endl;
  root = buildTree(sortingHelper, activeCells, 0, absoluteMin, absoluteMax);
  std::cout << " done." << std::endl;
}
// Collects the cells selected by recursiveIntersectionCells() for the
// ball of radius r centred on x.
//   x        centre of the search ball
//   r        search radius (its square is precomputed for the recursion)
//   cells    output array receiving pointers to the selected cells
//   numCells capacity of `cells`
// Returns the number of entries written to `cells`.
// Throws NotEnoughCells (raised by the recursion) when more than numCells
// cells are selected.
template<int N, typename ValType>
uint32_t NGBTree<N,ValType>::getIntersection(const coords& x, CoordType r,
                                             NGBCell<N, ValType> **cells,
                                             uint32_t numCells)
  throw (NotEnoughCells)
{
  RecursionInfoCells<N,ValType> query;

  // Output buffer description.
  query.cells = cells;
  query.numCells = numCells;
  query.currentRank = 0;

  // Search ball description.
  query.r = r;
  query.r2 = r*r;
  memcpy(query.x, x, sizeof(x));

  recursiveIntersectionCells(query, root, 0);
  return query.currentRank;
}
// Recursive worker of getIntersection().
// Visits the subtree rooted at `node` and appends to info.cells every
// active cell whose squared distance to info.x is strictly smaller than
// info.r2. A child subtree is only visited when the interval
// [x - r, x + r] along the splitting axis overlaps the child's extent.
//   info   search description and output buffer; info.currentRank is the
//          number of cells stored so far
//   node   root of the subtree to visit; a null node is ignored
//   level  depth of `node`, used to derive the splitting axis (level % N)
// Throws NotEnoughCells when a matching cell is found while the output
// buffer already holds info.numCells entries.
template<int N, typename ValType>
void NGBTree<N,ValType>::recursiveIntersectionCells(RecursionInfoCells<N,ValType>& info,
                                                    NGBTreeNode<N,ValType> *node,
                                                    int level)
  throw (NotEnoughCells)
{
  // An empty tree has a null root: nothing to report.
  if (node == 0)
    return;

  int axis = level % N;

  if (node->value->active)
    {
      // Squared distance over all N dimensions (the loop bound used to be
      // hard-coded to 3, which is only correct when N == 3).
      CoordType d2 = 0;
      for (int j = 0; j < N; j++)
        {
          CoordType delta = info.x[j]-node->value->coord[j];
          d2 += delta*delta;
        }
      if (d2 < info.r2)
        {
          if (info.currentRank == info.numCells)
            throw NotEnoughCells();
          info.cells[info.currentRank++] = node->value;
        }
    }

  // The hypersphere intersects the left child node
  if (((info.x[axis]+info.r) > node->minBound[axis]) &&
      ((info.x[axis]-info.r) < node->value->coord[axis]))
    {
      if (node->children[0] != 0)
        recursiveIntersectionCells(info, node->children[0],
                                   level+1);
    }

  // The hypersphere intersects the right child node
  if (((info.x[axis]+info.r) > node->value->coord[axis]) &&
      ((info.x[axis]-info.r) < node->maxBound[axis]))
    {
      if (node->children[1] != 0)
        recursiveIntersectionCells(info, node->children[1],
                                   level+1);
    }
}
// Reorders the pointer array `cells` in place so that all cells whose
// `active` flag is set come first, and returns how many there are.
//   cells   array of Ncells cell pointers (reordered, not reallocated)
//   Ncells  number of entries in `cells`; may be 0
// Returns the number of active cells, which is 0 when the array is empty
// or when no cell is active. The relative order of the cells is not
// preserved.
template<int N, typename ValType>
uint32_t gatherActiveCells(NGBCell<N,ValType> **cells,
                           uint32_t Ncells)
{
  // Invariant: [0, first) holds active cells, [last, Ncells) holds
  // inactive cells, and [first, last) is still unclassified. Using a
  // half-open range avoids the Ncells-1 underflow on an empty array and
  // lets the result reach 0 when nothing is active.
  uint32_t first = 0;
  uint32_t last = Ncells;

  while (first < last)
    {
      if (cells[first]->active)
        {
          first++;
        }
      else
        {
          // Park the inactive cell at the end; the entry swapped in is
          // examined on the next iteration.
          last--;
          std::swap(cells[first], cells[last]);
        }
    }
  return first;
}
// Recursively builds the subtree covering the Ncells cell pointers that
// start at cell0.
//   cell0     first entry of the pointer range to organise (sorted in place)
//   Ncells    number of entries in the range
//   depth     depth of the node being built; the splitting axis is depth % N
//   minBound  lower corner of the box assigned to this subtree
//   maxBound  upper corner of the box assigned to this subtree
// Returns the node created for this range, or 0 when the range is empty.
// The range is sorted along the splitting axis, the middle entry becomes
// the node's cell, and the entries before/after it form the children.
// Nodes are taken sequentially from `nodes` through `lastNode`.
template<int N, typename ValType>
NGBTreeNode<N, ValType> *NGBTree<N,ValType>::buildTree(NGBCell<N,ValType> **cell0,
                                                        uint32_t Ncells,
                                                        uint32_t depth,
                                                        coords minBound,
                                                        coords maxBound)
{
  if (Ncells == 0)
    return 0;

  const int axis = depth % N;
  const uint32_t mid = Ncells/2;

  // Order the range along the splitting axis so that cell0[mid] separates
  // the two halves.
  std::sort(cell0, cell0+Ncells, CellCompare<N,ValType>(axis));

  NGBTreeNode<N,ValType> *node = &nodes[lastNode++];
  node->value = cell0[mid];
  memcpy(&node->minBound[0], &minBound[0], sizeof(coords));
  memcpy(&node->maxBound[0], &maxBound[0], sizeof(coords));

  const uint32_t childDepth = depth + 1;

  // Left child: same box, clipped from above at the splitting coordinate.
  coords leftMax;
  memcpy(leftMax, maxBound, sizeof(coords));
  leftMax[axis] = node->value->coord[axis];
  node->children[0] = buildTree(cell0, mid, childDepth, minBound, leftMax);

  // Right child: same box, clipped from below at the splitting coordinate.
  coords rightMin;
  memcpy(rightMin, minBound, sizeof(coords));
  rightMin[axis] = node->value->coord[axis];
  node->children[1] = buildTree(cell0+mid+1, Ncells-mid-1, childDepth,
                                rightMin, maxBound);

  return node;
}
// Counts the tree nodes whose cell currently has its `active` flag set.
// Only the first `lastNode` entries of `nodes` (those handed out by
// buildTree) are inspected.
template<int N, typename ValType>
uint32_t NGBTree<N,ValType>::countActives() const
{
  uint32_t total = 0;
  uint32_t idx = 0;
  while (idx < lastNode)
    {
      total += nodes[idx].value->active ? 1 : 0;
      ++idx;
    }
  return total;
}
// Returns the SQUARED Euclidean distance between the coordinates of
// `cell` and the point x, summed over all N dimensions. No square root
// is taken.
template<int N, typename ValType>
typename NGBDef<N>::CoordType
NGBTree<N,ValType>::computeDistance(NGBCell<N, ValType> *cell, const coords& x)
{
  CoordType sumSquares = 0;
  int axis = 0;
  while (axis < N)
    {
      const CoordType diff = cell->coord[axis] - x[axis];
      sumSquares += diff*diff;
      ++axis;
    }
  return sumSquares;
}
// Recursive worker of getNearestNeighbour().
//   node   subtree to search (must not be null)
//   level  depth of `node`; the splitting axis is level % N
//   x      query point
//   R2     in/out: smallest squared distance found so far
//   best   in/out: cell achieving R2
// The child on the query's side of the splitting plane is searched first,
// then the node's own cell is tested, and finally the opposite child is
// searched only if the squared offset to the splitting plane is still
// smaller than R2.
template<int N, typename ValType>
void
NGBTree<N,ValType>::recursiveNearest(
                                     NGBTreeNode<N,ValType> *node,
                                     int level,
                                     const coords& x,
                                     CoordType& R2,
                                     NGBCell<N,ValType> *& best)
{
  const int axis = level % N;
  const bool queryOnLowSide = x[axis] < node->value->coord[axis];

  NGBTreeNode<N,ValType> *nearChild = node->children[queryOnLowSide ? 0 : 1];
  NGBTreeNode<N,ValType> *farChild = node->children[queryOnLowSide ? 1 : 0];

  // Only on the high side: when the preferred child is missing, descend
  // into the other one instead and leave nothing for the second pass.
  if (!queryOnLowSide && nearChild == 0)
    {
      nearChild = farChild;
      farChild = 0;
    }

  if (nearChild != 0)
    recursiveNearest(nearChild, level+1, x, R2, best);

  // Test this node's own cell against the current best.
  const CoordType nodeR2 = computeDistance(node->value, x);
  if (nodeR2 < R2)
    {
      R2 = nodeR2;
      best = node->value;
    }

  // Nothing was descended into: this node is treated as a leaf.
  if (nearChild == 0)
    return;

  // Visit the opposite branch only when the current best hypersphere
  // crosses the splitting hyperplane.
  const CoordType planeOffset = x[axis]-node->value->coord[axis];
  if (farChild != 0 && planeOffset*planeOffset < R2)
    recursiveNearest(farChild, level+1, x, R2, best);
}
// Returns the cell stored in the tree that is closest to x (smallest
// squared distance as computed by computeDistance), or 0 when the tree
// has no root.
template<int N, typename ValType>
NGBCell<N, ValType> *
NGBTree<N,ValType>::getNearestNeighbour(const coords& x)
{
  NGBCell<N,ValType> *best = 0;

  // An empty tree has a null root; recursiveNearest() requires a node.
  if (root == 0)
    return best;

  // Start from the largest representable squared distance. This uses
  // <limits> (already included by this file) rather than the INFINITY
  // macro, which lives in <cmath> and is a float constant.
  CoordType R2 = std::numeric_limits<CoordType>::has_infinity
    ? std::numeric_limits<CoordType>::infinity()
    : std::numeric_limits<CoordType>::max();

  recursiveNearest(root, 0, x, R2, best);
  return best;
}
// Recursive worker of getNearestNeighbours().
//   info   query state: info.x is the query point, info.queue receives
//          every visited cell together with its squared distance, and
//          info.traversed counts the visited nodes
//   node   subtree to search (must not be null)
//   level  depth of `node`; the splitting axis is level % N
// The child on the query's side of the splitting plane is searched first,
// then this node's cell is pushed, and the opposite child is searched only
// if the squared offset to the splitting plane is smaller than
// info.queue.getMaxPriority().
template<int N, typename ValType>
void
NGBTree<N,ValType>::recursiveMultipleNearest(RecursionMultipleInfo<N,ValType>& info, NGBTreeNode<N,ValType> *node,
                                             int level)
{
  const int axis = level % N;
  const bool queryOnLowSide = info.x[axis] < node->value->coord[axis];

  NGBTreeNode<N,ValType> *nearChild = node->children[queryOnLowSide ? 0 : 1];
  NGBTreeNode<N,ValType> *farChild = node->children[queryOnLowSide ? 1 : 0];

  // Only on the high side: when the preferred child is missing, descend
  // into the other one instead and leave nothing for the second pass.
  if (!queryOnLowSide && nearChild == 0)
    {
      nearChild = farChild;
      farChild = 0;
    }

  if (nearChild != 0)
    recursiveMultipleNearest(info, nearChild, level+1);

  // Offer this node's own cell to the candidate queue.
  CoordType nodeR2 =
    computeDistance(node->value, info.x);
  info.queue.push(node->value, nodeR2);
  info.traversed++;

  // Nothing was descended into: this node is treated as a leaf.
  if (nearChild == 0)
    return;

  // Visit the opposite branch only when the current search radius crosses
  // the splitting hyperplane.
  const CoordType planeOffset = info.x[axis]-node->value->coord[axis];
  if (planeOffset*planeOffset < info.queue.getMaxPriority())
    {
      if (farChild != 0)
        recursiveMultipleNearest(info, farChild, level+1);
    }
}
// Searches the tree for the cells nearest to x.
//   x      query point
//   N2     number of entries in `cells`
//   cells  output array of N2 cell pointers; every entry is reset to 0
//          before the search starts
// The search state is a RecursionMultipleInfo built from (x, cells, N2);
// presumably its queue stores the results into `cells` -- confirm against
// the RecursionMultipleInfo declaration. The number of visited nodes is
// written to std::cout. When the tree has no root, no search is performed
// and `cells` stays zero-filled.
template<int N, typename ValType>
void NGBTree<N,ValType>::getNearestNeighbours(const coords& x, uint32_t N2,
                                              NGBCell<N, ValType> **cells)
{
  RecursionMultipleInfo<N,ValType> info(x, cells, N2);

  // Unsigned index: N2 is a uint32_t, so a signed int would mix
  // signedness in the comparison and could overflow for large N2.
  for (uint32_t i = 0; i < N2; i++)
    cells[i] = 0;

  // An empty tree has a null root; recursiveMultipleNearest() requires a
  // node.
  if (root != 0)
    recursiveMultipleNearest(info, root, 0);

  std::cout << "Traversed = " << info.traversed << std::endl;
}
};