Fixed KDTree recursion and splitting technique

This commit is contained in:
Guilhem Lavaux 2012-05-21 14:04:21 -04:00
parent d4a14d7d85
commit 80951b4b52
3 changed files with 113 additions and 33 deletions

View File

@ -7,13 +7,14 @@
#include "mykdtree.hpp" #include "mykdtree.hpp"
#include "kdtree_splitters.hpp" #include "kdtree_splitters.hpp"
#define NTRY 100 #define NTRY 10
#define ND 3 #define ND 3
using namespace std; using namespace std;
using namespace CosmoTool; using namespace CosmoTool;
typedef KDTree<ND,char,ComputePrecision,KD_homogeneous_cell_splitter<ND, char> > MyTree; typedef KDTree<ND,char,ComputePrecision,KD_homogeneous_cell_splitter<ND, char> > MyTree;
//typedef KDTree<ND,char,ComputePrecision > MyTree;
typedef KDCell<ND,char> MyCell; typedef KDCell<ND,char> MyCell;
MyCell *findNearest(MyTree::coords& xc, MyCell *cells, uint32_t Ncells) MyCell *findNearest(MyTree::coords& xc, MyCell *cells, uint32_t Ncells)
@ -39,7 +40,7 @@ MyCell *findNearest(MyTree::coords& xc, MyCell *cells, uint32_t Ncells)
int main() int main()
{ {
uint32_t Ncells = 100000; uint32_t Ncells = 10000000;
MyCell *cells = new MyCell[Ncells]; MyCell *cells = new MyCell[Ncells];
for (int i = 0; i < Ncells; i++) for (int i = 0; i < Ncells; i++)
@ -49,7 +50,15 @@ int main()
cells[i].coord[l] = drand48(); cells[i].coord[l] = drand48();
} }
// Check timing
clock_t startTimer = clock();
MyTree tree(cells, Ncells); MyTree tree(cells, Ncells);
clock_t endTimer = clock();
clock_t delta = endTimer-startTimer;
double myTime = delta*1.0/CLOCKS_PER_SEC * 1.0;
cout << "KDTree build = " << myTime << " s" << endl;
MyTree::coords *xc = new MyTree::coords[NTRY]; MyTree::coords *xc = new MyTree::coords[NTRY];
@ -69,31 +78,39 @@ int main()
for (int k = 0; k < NTRY; k++) { for (int k = 0; k < NTRY; k++) {
cout << "Seed = " << xc[k][0] << " " << xc[k][1] << " " << xc[k][2] << endl; cout << "Seed = " << xc[k][0] << " " << xc[k][1] << " " << xc[k][2] << endl;
tree.getNearestNeighbours(xc[k], 12, ngb, distances); tree.getNearestNeighbours(xc[k], 12, ngb, distances);
int last = -1;
for (uint32_t i = 0; i < 12; i++) for (uint32_t i = 0; i < 12; i++)
{ {
if (ngb[i] == 0)
continue;
last = i;
double d2 = 0; double d2 = 0;
for (int l = 0; l < 3; l++) for (int l = 0; l < 3; l++)
d2 += ({double delta = xc[k][l] - ngb[i]->coord[l]; delta*delta;}); d2 += ({double delta = xc[k][l] - ngb[i]->coord[l]; delta*delta;});
fngb << ngb[i]->coord[0] << " " << ngb[i]->coord[1] << " " << ngb[i]->coord[2] << " " << sqrt(d2) << endl; fngb << ngb[i]->coord[0] << " " << ngb[i]->coord[1] << " " << ngb[i]->coord[2] << " " << sqrt(d2) << endl;
} }
fngb << endl << endl; fngb << endl << endl;
double farther_dist = distances[11]; double farther_dist = distances[last];
for (uint32_t i = 0; i < Ncells; i++) for (uint32_t i = 0; i < Ncells; i++)
{ {
bool found = false; bool found = false;
// If the points is not in the list, it means it is farther than the farther point // If the points is not in the list, it means it is farther than the farthest point
for (int j =0; j < 12; j++) for (int j =0; j < 12; j++)
{ {
if (&cells[i] == ngb[j]) { if (&cells[i] == ngb[j]) {
found = true; found = true;
break; break;
} }
} }
double dist_to_seed = 0; double dist_to_seed = 0;
for (int l = 0; l < 3; l++) for (int l = 0; l < 3; l++)
{ double delta = xc[k][l]-cells[i].coord[l]; {
dist_to_seed += delta*delta; } double delta = xc[k][l]-cells[i].coord[l];
dist_to_seed += delta*delta;
}
if (!found) if (!found)
{ {
if (dist_to_seed <= farther_dist) if (dist_to_seed <= farther_dist)

View File

@ -12,56 +12,119 @@ namespace CosmoTool
typedef typename KDDef<N,CType>::KDCoordinates coords; typedef typename KDDef<N,CType>::KDCoordinates coords;
typedef typename KDDef<N,CType>::CoordType ctype; typedef typename KDDef<N,CType>::CoordType ctype;
void check_splitting(KDCell<N,ValType,CType> **cells, uint32_t Ncells, int axis, uint32_t split_index, ctype midCoord)
{
ctype delta = std::numeric_limits<ctype>::max();
assert(split_index < Ncells);
assert(axis < N);
for (uint32_t i = 0; i < split_index; i++)
{
assert(cells[i]->coord[axis] <= midCoord);
delta = min(midCoord-cells[i]->coord[axis], delta);
}
for (uint32_t i = split_index+1; i < Ncells; i++)
{
assert(cells[i]->coord[axis] > midCoord);
delta = min(cells[i]->coord[axis]-midCoord, delta);
}
assert(delta >= 0);
assert (std::abs(cells[split_index]->coord[axis]-midCoord) <= delta);
}
void operator()(KDCell<N,ValType,CType> **cells, uint32_t Ncells, uint32_t& split_index, int axis, coords minBound, coords maxBound) void operator()(KDCell<N,ValType,CType> **cells, uint32_t Ncells, uint32_t& split_index, int axis, coords minBound, coords maxBound)
{ {
if (Ncells == 1)
{
split_index = 0;
return;
}
ctype midCoord = 0.5*(maxBound[axis]+minBound[axis]); ctype midCoord = 0.5*(maxBound[axis]+minBound[axis]);
uint32_t below = 0, above = Ncells-1; uint32_t below = 0, above = Ncells-1;
ctype delta_max = std::abs(cells[0]->coord[axis]-midCoord); ctype delta_min = std::numeric_limits<ctype>::max();
uint32_t idx_max = 0; uint32_t idx_min = std::numeric_limits<uint32_t>::max();
while (below < above) while (below < above)
{ {
ctype delta = cells[below]->coord[axis]-midCoord; ctype delta = cells[below]->coord[axis]-midCoord;
if (delta > 0) if (delta > 0)
{ {
if (delta < delta_max) if (delta < delta_min)
{ {
delta_max = delta; delta_min = delta;
idx_max = above; idx_min = above;
} }
std::swap(cells[below], cells[above--]); std::swap(cells[below], cells[above--]);
} }
else else
{ {
if (-delta < delta_max) if (-delta < delta_min)
{ {
delta_max = -delta; delta_min = -delta;
idx_max = below; idx_min = below;
} }
below++; below++;
} }
} }
if (idx_max != above) // Last iteration
{
ctype delta = cells[below]->coord[axis]-midCoord;
if (delta > 0)
{
if (delta < delta_min)
{
delta_min = delta;
idx_min = above;
}
}
else
{
if (-delta < delta_min)
{
delta_min = -delta;
idx_min = above;
}
}
}
if (idx_min != above)
{ {
bool cond1 = cells[idx_max]->coord[axis] > midCoord; bool cond1 = cells[idx_min]->coord[axis] > midCoord;
bool cond2 = cells[above]->coord[axis] > midCoord; bool cond2 = cells[above]->coord[axis] > midCoord;
if ((cond1 && cond2) || (!cond1 && !cond2)) if ((cond1 && cond2) || (!cond1 && !cond2))
{ {
split_index = above; split_index = above;
std::swap(cells[above], cells[idx_max]); std::swap(cells[above], cells[idx_min]);
} }
else if (cond2) else if (cond2)
{ {
split_index = above-1; if (above >= 1)
std::swap(cells[above-1], cells[idx_max]); {
split_index = above-1;
std::swap(cells[above-1], cells[idx_min]);
}
else
split_index = 0;
assert(split_index >= 0);
} }
else else
{ {
split_index = above+1; if (above+1 < Ncells)
std::swap(cells[above+1], cells[idx_max]); {
split_index = above+1;
std::swap(cells[above+1], cells[idx_min]);
}
else
split_index = Ncells-1;
assert(split_index < Ncells);
} }
} }
else split_index = above; else split_index = above;
// check_splitting(cells, Ncells, axis, split_index, midCoord);
} }
}; };

View File

@ -390,11 +390,11 @@ namespace CosmoTool {
// If not it is in 1. // If not it is in 1.
go = node->children[1]; go = node->children[1];
other = node->children[0]; other = node->children[0];
if (go == 0) // if (go == 0)
{ // {
go = other; // go = other;
other = 0; //other = 0;
} //}
} }
if (go != 0) if (go != 0)
@ -407,8 +407,8 @@ namespace CosmoTool {
computeDistance(node->value, info.x); computeDistance(node->value, info.x);
info.queue.push(node->value, thisR2); info.queue.push(node->value, thisR2);
info.traversed++; info.traversed++;
if (go == 0) // if (go == 0)
return; // return;
// Now we found the best. We check whether the hypersphere // Now we found the best. We check whether the hypersphere
// intersect the hyperplane of the other branch // intersect the hyperplane of the other branch