Reformat and add some adjoint gradient
This commit is contained in:
parent
fe06434619
commit
046e9a1447
@ -1,5 +1,5 @@
|
||||
/*+
|
||||
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2014)
|
||||
This is CosmoTool (./src/mykdtree.hpp) -- Copyright (C) Guilhem Lavaux (2007-2022)
|
||||
|
||||
guilhem.lavaux@gmail.com
|
||||
|
||||
|
@ -79,14 +79,30 @@ namespace CosmoTool {
|
||||
return internal.currentCenter;
|
||||
}
|
||||
|
||||
/** This is the pure SPH smoothing function. It does not reweight by the
|
||||
* value computed at each grid site.
|
||||
*/
|
||||
template <typename FuncT>
|
||||
ComputePrecision computeSmoothedValue(
|
||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||
|
||||
/** This is the weighted SPH smoothing function. It does reweight by the
|
||||
* value computed at each grid site. This ensures the total sum of the interpolated
|
||||
* quantity is preserved by interpolating to the target mesh.
|
||||
*/
|
||||
template <typename FuncT>
|
||||
ComputePrecision computeInterpolatedValue(
|
||||
const typename SPHTree::coords &c, FuncT fun, SPHState *state = 0);
|
||||
|
||||
/** This is the adjoint gradient of computeInterpolatedValue w.r.t. to the value
|
||||
* array. FuncT is expected to have the following prototype:
|
||||
* void((CellValue defined by the user), ComputePrecision weighted_ag_value)
|
||||
*/
|
||||
template <typename FuncT>
|
||||
void computeAdjointGradientSmoothedValue(
|
||||
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||
SPHState *state = 0);
|
||||
|
||||
ComputePrecision
|
||||
getMaxDistance(const typename SPHTree::coords &c, SPHNode *node) const;
|
||||
|
||||
|
@ -134,6 +134,34 @@ namespace CosmoTool {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
template <typename ValType, int Ndims>
|
||||
template <typename FuncT>
|
||||
void SPHSmooth<ValType, Ndims>::computeAdjointGradientSmoothedValue(
|
||||
const typename SPHTree::coords &c, ComputePrecision ag_value, FuncT fun,
|
||||
SPHState *state) {
|
||||
if (state == 0)
|
||||
state = &internal;
|
||||
|
||||
ComputePrecision outputValue = 0;
|
||||
ComputePrecision max_dist = 0;
|
||||
ComputePrecision weight = 0;
|
||||
|
||||
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||
weight +=
|
||||
computeWValue(c, *state->ngb[i], state->distances[i], interpolateOne);
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < state->currentNgb; i++) {
|
||||
auto &cell = *state->ngb[i];
|
||||
double partial_ag =
|
||||
computeWValue(
|
||||
c, cell, state->distances[i],
|
||||
[ag_value](ComputePrecision) { return ag_value; }) /
|
||||
weight;
|
||||
fun(cell.val.pValue, ag_value);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING ! Cell's weight must be 1 !!!
|
||||
template <typename ValType, int Ndims>
|
||||
template <typename FuncT>
|
||||
|
Loading…
Reference in New Issue
Block a user