Reformat and add some adjoint gradient
This commit is contained in:
parent
fe06434619
commit
046e9a1447
3 changed files with 61 additions and 17 deletions
|
@ -1,5 +1,5 @@
|
|||
/*+
|
||||
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2014)
|
||||
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2022)
|
||||
|
||||
guilhem.lavaux@gmail.com
|
||||
|
||||
|
@ -7,16 +7,16 @@ This software is a computer program whose purpose is to provide a toolbox for co
|
|||
data analysis (e.g. filters, generalized Fourier transforms, power spectra, ...)
|
||||
|
||||
This software is governed by the CeCILL license under French law and
|
||||
abiding by the rules of distribution of free software. You can use,
|
||||
abiding by the rules of distribution of free software. You can use,
|
||||
modify and/ or redistribute the software under the terms of the CeCILL
|
||||
license as circulated by CEA, CNRS and INRIA at the following URL
|
||||
"http://www.cecill.info".
|
||||
"http://www.cecill.info".
|
||||
|
||||
As a counterpart to the access to the source code and rights to copy,
|
||||
modify and redistribute granted by the license, users are provided only
|
||||
with a limited warranty and the software's author, the holder of the
|
||||
economic rights, and the successive licensors have only limited
|
||||
liability.
|
||||
liability.
|
||||
|
||||
In this respect, the user's attention is drawn to the risks associated
|
||||
with loading, using, modifying and/or developing or reproducing the
|
||||
|
@ -25,9 +25,9 @@ that may mean that it is complicated to manipulate, and that also
|
|||
therefore means that it is reserved for developers and experienced
|
||||
professionals having in-depth computer knowledge. Users are therefore
|
||||
encouraged to load and test the software's suitability as regards their
|
||||
requirements in conditions enabling the security of their systems and/or
|
||||
data to be ensured and, more generally, to use and operate it in the
|
||||
same conditions as regards security.
|
||||
requirements in conditions enabling the security of their systems and/or
|
||||
data to be ensured and, more generally, to use and operate it in the
|
||||
same conditions as regards security.
|
||||
|
||||
The fact that you are presently reading this means that you have had
|
||||
knowledge of the CeCILL license and that you accept its terms.
|
||||
|
@ -48,13 +48,13 @@ namespace CosmoTool {
|
|||
|
||||
typedef uint64_t NodeIntType;
|
||||
|
||||
template<int N, typename CType = ComputePrecision>
|
||||
template<int N, typename CType = ComputePrecision>
|
||||
struct KDDef
|
||||
{
|
||||
typedef CType CoordType;
|
||||
typedef float KDCoordinates[N];
|
||||
};
|
||||
|
||||
|
||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||
struct KDCell
|
||||
{
|
||||
|
@ -102,7 +102,7 @@ namespace CosmoTool {
|
|||
uint64_t currentRank;
|
||||
uint64_t numCells;
|
||||
};
|
||||
|
||||
|
||||
|
||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||
class RecursionMultipleInfo
|
||||
|
@ -121,7 +121,7 @@ namespace CosmoTool {
|
|||
}
|
||||
};
|
||||
|
||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||
template<int N, typename ValType, typename CType = ComputePrecision>
|
||||
struct KD_default_cell_splitter
|
||||
{
|
||||
void operator()(KDCell<N,ValType,CType> **cells, NodeIntType Ncells, NodeIntType& split_index, int axis, typename KDDef<N,CType>::KDCoordinates minBound, typename KDDef<N,CType>::KDCoordinates maxBound);
|
||||
|
@ -135,7 +135,7 @@ namespace CosmoTool {
|
|||
typedef typename KDDef<N>::KDCoordinates coords;
|
||||
typedef KDCell<N,ValType,CType> Cell;
|
||||
typedef KDTreeNode<N,ValType,CType> Node;
|
||||
|
||||
|
||||
CellSplitter splitter;
|
||||
|
||||
KDTree(Cell *cells, NodeIntType Ncells);
|
||||
|
@ -153,10 +153,10 @@ namespace CosmoTool {
|
|||
std::copy(replicate, replicate+N, this->replicate);
|
||||
}
|
||||
|
||||
uint64_t getIntersection(const coords& x, CoordType r,
|
||||
uint64_t getIntersection(const coords& x, CoordType r,
|
||||
Cell **cells,
|
||||
uint64_t numCells);
|
||||
uint64_t getIntersection(const coords& x, CoordType r,
|
||||
uint64_t getIntersection(const coords& x, CoordType r,
|
||||
Cell **cells,
|
||||
CoordType *distances,
|
||||
uint64_t numCells);
|
||||
|
@ -183,7 +183,7 @@ namespace CosmoTool {
|
|||
NodeIntType getNumberInNode(const Node *n) const { return n->numNodes; }
|
||||
#else
|
||||
NodeIntType getNumberInNode(const Node *n) const {
|
||||
if (n == 0)
|
||||
if (n == 0)
|
||||
return 0;
|
||||
return 1+getNumberInNode(n->children[0])+getNumberInNode(n->children[1]);
|
||||
}
|
||||
|
@ -211,7 +211,7 @@ namespace CosmoTool {
|
|||
uint32_t depth,
|
||||
coords minBound,
|
||||
coords maxBound);
|
||||
|
||||
|
||||
template<bool justCount>
|
||||
void recursiveIntersectionCells(RecursionInfoCells<N,ValType, CType>& info,
|
||||
Node *node,
|
||||
|
@ -224,7 +224,7 @@ namespace CosmoTool {
|
|||
CoordType& R2,
|
||||
Cell*& cell);
|
||||
void recursiveMultipleNearest(RecursionMultipleInfo<N,ValType,CType>& info, Node *node,
|
||||
int level);
|
||||
int level);
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -79,14 +79,30 @@ namespace CosmoTool {
|
|||
return internal.currentCenter;
|
||||
}
|
||||
|
||||
/** This is the pure SPH smoothing function. It does not reweight by the
|
||||
* value computed at each grid site.
|
||||
*/
|
||||
template <typename FuncT>
|
||||
ComputePrecision computeSmoothedValue(
|
||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||
|
||||
/** This is the weighted SPH smoothing function. It does reweight by the
|
||||
* value computed at each grid site. This ensures the total sum of the interpolated
|
||||
* quantity is preserved by interpolating to the target mesh.
|
||||
*/
|
||||
template <typename FuncT>
|
||||
ComputePrecision computeInterpolatedValue(
|
||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||
|
||||
/** This is the adjoint gradient of computeInterpolatedValue w.r.t. to the value
|
||||
* array. FuncT is expected to have the following prototype:
|
||||
* void((CellValue defined by the user), ComputePrecision weighted_ag_value)
|
||||
*/
|
||||
template <typename FuncT>
|
||||
void computeAdjointGradientSmoothedValue(
|
||||
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||
SPHState *state = 0);
|
||||
|
||||
ComputePrecision
|
||||
getMaxDistance(const typename SPHTree::coords &c, SPHNode *node) const;
|
||||
|
||||
|
|
|
@ -134,6 +134,34 @@ namespace CosmoTool {
|
|||
return 1.0;
|
||||
}
|
||||
|
||||
template <typename ValType, int Ndims>
|
||||
template <typename FuncT>
|
||||
void SPHSmooth<ValType, Ndims>::computeAdjointGradientSmoothedValue(
|
||||
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||
SPHState *state) {
|
||||
if (state == 0)
|
||||
state = &internal;
|
||||
|
||||
ComputePrecision outputValue = 0;
|
||||
ComputePrecision max_dist = 0;
|
||||
ComputePrecision weight = 0;
|
||||
|
||||
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||
weight +=
|
||||
computeWValue(c, *state->ngb[i], state->distances[i], interpolateOne);
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||
auto &cell = *state->ngb[i];
|
||||
double partial_ag =
|
||||
computeWValue(
|
||||
c, cell, state->distances[i],
|
||||
[ag_value](ComputePrecision) { return ag_value; }) /
|
||||
weight;
|
||||
fun(cell.val.pValue, ag_value);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING ! Cell's weight must be 1 !!!
|
||||
template <typename ValType, int Ndims>
|
||||
template <typename FuncT>
|
||||
|
|
Loading…
Reference in a new issue