Reformat and add some adjoint gradient
This commit is contained in:
parent
fe06434619
commit
046e9a1447
@ -1,5 +1,5 @@
|
|||||||
/*+
|
/*+
|
||||||
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2014)
|
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2022)
|
||||||
|
|
||||||
guilhem.lavaux@gmail.com
|
guilhem.lavaux@gmail.com
|
||||||
|
|
||||||
@ -7,16 +7,16 @@ This software is a computer program whose purpose is to provide a toolbox for co
|
|||||||
data analysis (e.g. filters, generalized Fourier transforms, power spectra, ...)
|
data analysis (e.g. filters, generalized Fourier transforms, power spectra, ...)
|
||||||
|
|
||||||
This software is governed by the CeCILL license under French law and
|
This software is governed by the CeCILL license under French law and
|
||||||
abiding by the rules of distribution of free software. You can use,
|
abiding by the rules of distribution of free software. You can use,
|
||||||
modify and/ or redistribute the software under the terms of the CeCILL
|
modify and/ or redistribute the software under the terms of the CeCILL
|
||||||
license as circulated by CEA, CNRS and INRIA at the following URL
|
license as circulated by CEA, CNRS and INRIA at the following URL
|
||||||
"http://www.cecill.info".
|
"http://www.cecill.info".
|
||||||
|
|
||||||
As a counterpart to the access to the source code and rights to copy,
|
As a counterpart to the access to the source code and rights to copy,
|
||||||
modify and redistribute granted by the license, users are provided only
|
modify and redistribute granted by the license, users are provided only
|
||||||
with a limited warranty and the software's author, the holder of the
|
with a limited warranty and the software's author, the holder of the
|
||||||
economic rights, and the successive licensors have only limited
|
economic rights, and the successive licensors have only limited
|
||||||
liability.
|
liability.
|
||||||
|
|
||||||
In this respect, the user's attention is drawn to the risks associated
|
In this respect, the user's attention is drawn to the risks associated
|
||||||
with loading, using, modifying and/or developing or reproducing the
|
with loading, using, modifying and/or developing or reproducing the
|
||||||
@ -25,9 +25,9 @@ that may mean that it is complicated to manipulate, and that also
|
|||||||
therefore means that it is reserved for developers and experienced
|
therefore means that it is reserved for developers and experienced
|
||||||
professionals having in-depth computer knowledge. Users are therefore
|
professionals having in-depth computer knowledge. Users are therefore
|
||||||
encouraged to load and test the software's suitability as regards their
|
encouraged to load and test the software's suitability as regards their
|
||||||
requirements in conditions enabling the security of their systems and/or
|
requirements in conditions enabling the security of their systems and/or
|
||||||
data to be ensured and, more generally, to use and operate it in the
|
data to be ensured and, more generally, to use and operate it in the
|
||||||
same conditions as regards security.
|
same conditions as regards security.
|
||||||
|
|
||||||
The fact that you are presently reading this means that you have had
|
The fact that you are presently reading this means that you have had
|
||||||
knowledge of the CeCILL license and that you accept its terms.
|
knowledge of the CeCILL license and that you accept its terms.
|
||||||
@ -48,13 +48,13 @@ namespace CosmoTool {
|
|||||||
|
|
||||||
typedef uint64_t NodeIntType;
|
typedef uint64_t NodeIntType;
|
||||||
|
|
||||||
template<int N, typename CType = ComputePrecision>
|
template<int N, typename CType = ComputePrecision>
|
||||||
struct KDDef
|
struct KDDef
|
||||||
{
|
{
|
||||||
typedef CType CoordType;
|
typedef CType CoordType;
|
||||||
typedef float KDCoordinates[N];
|
typedef float KDCoordinates[N];
|
||||||
};
|
};
|
||||||
|
|
||||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||||
struct KDCell
|
struct KDCell
|
||||||
{
|
{
|
||||||
@ -102,7 +102,7 @@ namespace CosmoTool {
|
|||||||
uint64_t currentRank;
|
uint64_t currentRank;
|
||||||
uint64_t numCells;
|
uint64_t numCells;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||||
class RecursionMultipleInfo
|
class RecursionMultipleInfo
|
||||||
@ -121,7 +121,7 @@ namespace CosmoTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||||
struct KD_default_cell_splitter
|
struct KD_default_cell_splitter
|
||||||
{
|
{
|
||||||
void operator()(KDCell<N,ValType,CType> **cells, NodeIntType Ncells, NodeIntType& split_index, int axis, typename KDDef<N,CType>::KDCoordinates minBound, typename KDDef<N,CType>::KDCoordinates maxBound);
|
void operator()(KDCell<N,ValType,CType> **cells, NodeIntType Ncells, NodeIntType& split_index, int axis, typename KDDef<N,CType>::KDCoordinates minBound, typename KDDef<N,CType>::KDCoordinates maxBound);
|
||||||
@ -135,7 +135,7 @@ namespace CosmoTool {
|
|||||||
typedef typename KDDef<N>::KDCoordinates coords;
|
typedef typename KDDef<N>::KDCoordinates coords;
|
||||||
typedef KDCell<N,ValType,CType> Cell;
|
typedef KDCell<N,ValType,CType> Cell;
|
||||||
typedef KDTreeNode<N,ValType,CType> Node;
|
typedef KDTreeNode<N,ValType,CType> Node;
|
||||||
|
|
||||||
CellSplitter splitter;
|
CellSplitter splitter;
|
||||||
|
|
||||||
KDTree(Cell *cells, NodeIntType Ncells);
|
KDTree(Cell *cells, NodeIntType Ncells);
|
||||||
@ -153,10 +153,10 @@ namespace CosmoTool {
|
|||||||
std::copy(replicate, replicate+N, this->replicate);
|
std::copy(replicate, replicate+N, this->replicate);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t getIntersection(const coords& x, CoordType r,
|
uint64_t getIntersection(const coords& x, CoordType r,
|
||||||
Cell **cells,
|
Cell **cells,
|
||||||
uint64_t numCells);
|
uint64_t numCells);
|
||||||
uint64_t getIntersection(const coords& x, CoordType r,
|
uint64_t getIntersection(const coords& x, CoordType r,
|
||||||
Cell **cells,
|
Cell **cells,
|
||||||
CoordType *distances,
|
CoordType *distances,
|
||||||
uint64_t numCells);
|
uint64_t numCells);
|
||||||
@ -183,7 +183,7 @@ namespace CosmoTool {
|
|||||||
NodeIntType getNumberInNode(const Node *n) const { return n->numNodes; }
|
NodeIntType getNumberInNode(const Node *n) const { return n->numNodes; }
|
||||||
#else
|
#else
|
||||||
NodeIntType getNumberInNode(const Node *n) const {
|
NodeIntType getNumberInNode(const Node *n) const {
|
||||||
if (n == 0)
|
if (n == 0)
|
||||||
return 0;
|
return 0;
|
||||||
return 1+getNumberInNode(n->children[0])+getNumberInNode(n->children[1]);
|
return 1+getNumberInNode(n->children[0])+getNumberInNode(n->children[1]);
|
||||||
}
|
}
|
||||||
@ -211,7 +211,7 @@ namespace CosmoTool {
|
|||||||
uint32_t depth,
|
uint32_t depth,
|
||||||
coords minBound,
|
coords minBound,
|
||||||
coords maxBound);
|
coords maxBound);
|
||||||
|
|
||||||
template<bool justCount>
|
template<bool justCount>
|
||||||
void recursiveIntersectionCells(RecursionInfoCells<N,ValType, CType>& info,
|
void recursiveIntersectionCells(RecursionInfoCells<N,ValType, CType>& info,
|
||||||
Node *node,
|
Node *node,
|
||||||
@ -224,7 +224,7 @@ namespace CosmoTool {
|
|||||||
CoordType& R2,
|
CoordType& R2,
|
||||||
Cell*& cell);
|
Cell*& cell);
|
||||||
void recursiveMultipleNearest(RecursionMultipleInfo<N,ValType,CType>& info, Node *node,
|
void recursiveMultipleNearest(RecursionMultipleInfo<N,ValType,CType>& info, Node *node,
|
||||||
int level);
|
int level);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -79,14 +79,30 @@ namespace CosmoTool {
|
|||||||
return internal.currentCenter;
|
return internal.currentCenter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** This is the pure SPH smoothing function. It does not reweight by the
|
||||||
|
* value computed at each grid site.
|
||||||
|
*/
|
||||||
template <typename FuncT>
|
template <typename FuncT>
|
||||||
ComputePrecision computeSmoothedValue(
|
ComputePrecision computeSmoothedValue(
|
||||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||||
|
|
||||||
|
/** This is the weighted SPH smoothing function. It does reweight by the
|
||||||
|
* value computed at each grid site. This ensures the total sum of the interpolated
|
||||||
|
* quantity is preserved by interpolating to the target mesh.
|
||||||
|
*/
|
||||||
template <typename FuncT>
|
template <typename FuncT>
|
||||||
ComputePrecision computeInterpolatedValue(
|
ComputePrecision computeInterpolatedValue(
|
||||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||||
|
|
||||||
|
/** This is the adjoint gradient of computeInterpolatedValue w.r.t. to the value
|
||||||
|
* array. FuncT is expected to have the following prototype:
|
||||||
|
* void((CellValue defined by the user), ComputePrecision weighted_ag_value)
|
||||||
|
*/
|
||||||
|
template <typename FuncT>
|
||||||
|
void computeAdjointGradientSmoothedValue(
|
||||||
|
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||||
|
SPHState *state = 0);
|
||||||
|
|
||||||
ComputePrecision
|
ComputePrecision
|
||||||
getMaxDistance(const typename SPHTree::coords &c, SPHNode *node) const;
|
getMaxDistance(const typename SPHTree::coords &c, SPHNode *node) const;
|
||||||
|
|
||||||
|
@ -134,6 +134,34 @@ namespace CosmoTool {
|
|||||||
return 1.0;
|
return 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ValType, int Ndims>
|
||||||
|
template <typename FuncT>
|
||||||
|
void SPHSmooth<ValType, Ndims>::computeAdjointGradientSmoothedValue(
|
||||||
|
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||||
|
SPHState *state) {
|
||||||
|
if (state == 0)
|
||||||
|
state = &internal;
|
||||||
|
|
||||||
|
ComputePrecision outputValue = 0;
|
||||||
|
ComputePrecision max_dist = 0;
|
||||||
|
ComputePrecision weight = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||||
|
weight +=
|
||||||
|
computeWValue(c, *state->ngb[i], state->distances[i], interpolateOne);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||||
|
auto &cell = *state->ngb[i];
|
||||||
|
double partial_ag =
|
||||||
|
computeWValue(
|
||||||
|
c, cell, state->distances[i],
|
||||||
|
[ag_value](ComputePrecision) { return ag_value; }) /
|
||||||
|
weight;
|
||||||
|
fun(cell.val.pValue, ag_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WARNING ! Cell's weight must be 1 !!!
|
// WARNING ! Cell's weight must be 1 !!!
|
||||||
template <typename ValType, int Ndims>
|
template <typename ValType, int Ndims>
|
||||||
template <typename FuncT>
|
template <typename FuncT>
|
||||||
|
Loading…
Reference in New Issue
Block a user