Initial import
This commit is contained in:
commit
56a50eead3
820 changed files with 192077 additions and 0 deletions
513
libLSS/mcmc/global_state.hpp
Normal file
513
libLSS/mcmc/global_state.hpp
Normal file
|
@ -0,0 +1,513 @@
|
|||
/*+
|
||||
ARES/HADES/BORG Package -- ./libLSS/mcmc/global_state.hpp
|
||||
Copyright (C) 2014-2020 Guilhem Lavaux <guilhem.lavaux@iap.fr>
|
||||
Copyright (C) 2009-2020 Jens Jasche <jens.jasche@fysik.su.se>
|
||||
|
||||
Additional contributions from:
|
||||
Guilhem Lavaux <guilhem.lavaux@iap.fr> (2023)
|
||||
|
||||
+*/
|
||||
#ifndef _GLOBAL_STATE_HPP
|
||||
#define _GLOBAL_STATE_HPP
|
||||
|
||||
#include <boost/type_traits/is_base_of.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <typeindex>
|
||||
#include <algorithm>
|
||||
#include "libLSS/mpi/generic_mpi.hpp"
|
||||
#include "libLSS/tools/console.hpp"
|
||||
#include "libLSS/mcmc/state_element.hpp"
|
||||
|
||||
namespace LibLSS {
|
||||
|
||||
/**
|
||||
* @brief This is the class that manages the dictionnary that is saved in each MCMC/Restart file.
|
||||
*
|
||||
* It is *not* copy-constructible.
|
||||
*/
|
||||
class MarkovState {
|
||||
public:
|
||||
typedef std::map<std::string, bool> SaveMap;
|
||||
typedef std::map<std::string, StateElement *> StateMap;
|
||||
typedef std::map<std::string, std::type_index> TypeMap;
|
||||
typedef std::set<std::string> Requirements;
|
||||
|
||||
private:
|
||||
SaveMap save_map;
|
||||
StateMap state_map, toProcess;
|
||||
TypeMap type_map;
|
||||
std::list<std::tuple<Requirements, std::function<void()>>> postLoad;
|
||||
std::set<std::string> loaded;
|
||||
|
||||
public:
|
||||
MarkovState(MarkovState const &) = delete;
|
||||
|
||||
/**
|
||||
* @brief Construct a new empty Markov State object.
|
||||
*
|
||||
*/
|
||||
MarkovState() {}
|
||||
|
||||
/**
|
||||
* @brief Destroy the Markov State object.
|
||||
*
|
||||
* All the elements stored in the dictionnary will be destroyed, as the ownership
|
||||
* is given the dictionnary implicitly when the element is added to it.
|
||||
*/
|
||||
~MarkovState() {
|
||||
for (StateMap::iterator i = state_map.begin(); i != state_map.end();
|
||||
++i) {
|
||||
Console::instance().print<LOG_VERBOSE>(
|
||||
boost::format("Destroying %s") % i->first);
|
||||
delete i->second;
|
||||
}
|
||||
save_map.clear();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void check_class() {
|
||||
BOOST_MPL_ASSERT_MSG(
|
||||
(boost::is_base_of<StateElement, T>::value), T_is_not_a_StateElement,
|
||||
());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Function to access by its name an element stored in the dictionnary.
|
||||
*
|
||||
* This function makes a lookup and a dynamic cast to the specified template "StateElement".
|
||||
* It tries to find the indicated state element by name. If it fails an error is thrown.
|
||||
* A dynamic cast is then issued to ensure that the stored type is the same as the requested one.
|
||||
*
|
||||
* @tparam T type of the element, cast will be checked
|
||||
* @param name string id of the element
|
||||
* @return T* pointer to the element
|
||||
*/
|
||||
template <typename T>
|
||||
T *get(const std::string &name) {
|
||||
check_class<T>();
|
||||
StateMap::iterator i = state_map.find(name);
|
||||
if (i == state_map.end() || i->second == 0) {
|
||||
error_helper<ErrorBadState>(
|
||||
boost::format("Invalid access to %s") % name);
|
||||
}
|
||||
T *ptr = dynamic_cast<T *>(i->second);
|
||||
if (ptr == 0) {
|
||||
error_helper<ErrorBadCast>(
|
||||
boost::format("Bad cast in access to %s") % name);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Access using a boost::format object.
|
||||
*
|
||||
* @tparam T
|
||||
* @param f
|
||||
* @return T*
|
||||
*/
|
||||
template <typename T>
|
||||
T *get(const boost::format &f) {
|
||||
return get<T>(f.str());
|
||||
}
|
||||
|
||||
static void _format_expansion(boost::format &f) {}
|
||||
|
||||
template <typename A, typename... U>
|
||||
static void _format_expansion(boost::format &f, A &&a, U &&... u) {
|
||||
_format_expansion(f % a, u...);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
T *formatGet(std::string const &s, Args &&... args) {
|
||||
boost::format f(s);
|
||||
_format_expansion(f, std::forward<Args>(args)...);
|
||||
return get<T>(f);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T *get(const boost::format &f) const {
|
||||
return get<T>(f.str());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T *get(const std::string &name) const {
|
||||
check_class<T>();
|
||||
StateMap::const_iterator i = state_map.find(name);
|
||||
if (i == state_map.end() || i->second == 0) {
|
||||
error_helper<ErrorBadState>(
|
||||
boost::format("Invalid access to %s") % name);
|
||||
}
|
||||
|
||||
const T *ptr = dynamic_cast<const T *>(i->second);
|
||||
if (ptr == 0) {
|
||||
error_helper<ErrorBadCast>(
|
||||
boost::format("Bad cast in access to %s") % name);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check existence of an element in the dictionnary.
|
||||
*
|
||||
* @param name string id of the element
|
||||
* @return true if it exists
|
||||
* @return false if it does not exist
|
||||
*/
|
||||
bool exists(const std::string &name) const {
|
||||
return state_map.find(name) != state_map.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Access an element through operator [] overload.
|
||||
*
|
||||
* @param name
|
||||
* @return StateElement&
|
||||
*/
|
||||
StateElement &operator[](const std::string &name) {
|
||||
return *get<StateElement>(name);
|
||||
}
|
||||
|
||||
const StateElement &operator[](const std::string &name) const {
|
||||
return *get<StateElement>(name);
|
||||
}
|
||||
|
||||
std::type_index getStoredType(const std::string &name) const {
|
||||
auto iter = type_map.find(name);
|
||||
if (iter == type_map.end())
|
||||
error_helper<ErrorBadState>(
|
||||
"Unknown entry " + name + " during type query");
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add an element in the dictionnary.
|
||||
*
|
||||
* @param name string id of the new element
|
||||
* @param elt Object to add in the dictionnary. The ownership is transferred to MarkovState.
|
||||
* @param write_to_snapshot indicate, if true, that the element has to be written in mcmc files
|
||||
* @return StateElement* the same object as "elt", used to daisy chain calls.
|
||||
*/
|
||||
template <typename T>
|
||||
T *newElement(
|
||||
const std::string &name, T *elt,
|
||||
const bool &write_to_snapshot = false) {
|
||||
static_assert(
|
||||
std::is_base_of<StateElement, T>::value,
|
||||
"newElement accepts only StateElement based objects");
|
||||
state_map[name] = elt;
|
||||
type_map.insert(std::pair<std::string, std::type_index>(
|
||||
name, std::type_index(typeid(T))));
|
||||
toProcess[name] = elt;
|
||||
elt->name = name;
|
||||
set_save_in_snapshot(name, write_to_snapshot);
|
||||
return elt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add an element in the dictionnary.
|
||||
*
|
||||
* @param f boost::format object used to build the string-id
|
||||
* @param elt Object to add in the dictionnary. The ownership is transferred to MarkovState.
|
||||
* @param write_to_snapshot indicate, if true, that the element has to be written in mcmc files
|
||||
* @return StateElement* the same object as "elt", used to daisy chain calls.
|
||||
*/
|
||||
template <typename T>
|
||||
T *newElement(
|
||||
const boost::format &f, T *elt, const bool &write_to_snapshot = false) {
|
||||
return newElement(f.str(), elt, write_to_snapshot);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the content of a series of variables into a static array
|
||||
* That function is an helper to retrieve the value of a series "variable0",
|
||||
* "variable1", ..., "variableQ" of ScalarElement of type Scalar (with Q=N-1).
|
||||
* Such a case is for the length:
|
||||
* @code
|
||||
* double L[3];
|
||||
* state.getScalarArray<double, 3>("L", L);
|
||||
* @endcode
|
||||
* This will retrieve L0, L1 and L2 and store their value (double float) in
|
||||
* L[0], L[1], L2].
|
||||
*
|
||||
* @tparam Scalar inner type of the variable to be retrieved in the dictionnary
|
||||
* @tparam N number of elements
|
||||
* @param prefix prefix for these variables
|
||||
* @param scalars output scalar array
|
||||
*/
|
||||
template <typename Scalar, size_t N, typename ScalarArray>
|
||||
void getScalarArray(const std::string &prefix, ScalarArray &&scalars) {
|
||||
for (unsigned int i = 0; i < N; i++) {
|
||||
scalars[i] = getScalar<Scalar>(prefix + std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
///@deprecated
|
||||
template <typename Scalar>
|
||||
Scalar &getSyncScalar(const std::string &name) {
|
||||
return this->template get<SyncableScalarStateElement<Scalar>>(name)
|
||||
->value;
|
||||
}
|
||||
|
||||
///@deprecated
|
||||
template <typename Scalar>
|
||||
Scalar &getSyncScalar(const boost::format &name) {
|
||||
return this->template getSyncScalar<Scalar>(name.str());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the value of a scalar object.
|
||||
*
|
||||
* @tparam Scalar
|
||||
* @param name
|
||||
* @return Scalar&
|
||||
*/
|
||||
template <typename Scalar>
|
||||
Scalar &getScalar(const std::string &name) {
|
||||
return this->template get<ScalarStateElement<Scalar>>(name)->value;
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
Scalar &getScalar(const boost::format &name) {
|
||||
return this->template getScalar<Scalar>(name.str());
|
||||
}
|
||||
|
||||
template <typename Scalar, typename... U>
|
||||
Scalar &formatGetScalar(std::string const &name, U &&... u) {
|
||||
return this
|
||||
->template formatGet<ScalarStateElement<Scalar>>(
|
||||
name, std::forward<U>(u)...)
|
||||
->value;
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
ScalarStateElement<Scalar> *newScalar(
|
||||
const std::string &name, Scalar x,
|
||||
const bool &write_to_snapshot = false) {
|
||||
ScalarStateElement<Scalar> *elt = new ScalarStateElement<Scalar>();
|
||||
|
||||
elt->value = x;
|
||||
newElement(name, elt, write_to_snapshot);
|
||||
return elt;
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
ScalarStateElement<Scalar> *newScalar(
|
||||
const boost::format &name, Scalar x,
|
||||
const bool &write_to_snapshot = false) {
|
||||
return this->newScalar(name.str(), x, write_to_snapshot);
|
||||
}
|
||||
|
||||
///@deprecated
|
||||
template <typename Scalar>
|
||||
SyncableScalarStateElement<Scalar> *newSyScalar(
|
||||
const std::string &name, Scalar x,
|
||||
const bool &write_to_snapshot = false) {
|
||||
SyncableScalarStateElement<Scalar> *elt =
|
||||
new SyncableScalarStateElement<Scalar>();
|
||||
|
||||
elt->value = x;
|
||||
newElement(name, elt, write_to_snapshot);
|
||||
return elt;
|
||||
}
|
||||
|
||||
///@deprecated
|
||||
template <typename Scalar>
|
||||
SyncableScalarStateElement<Scalar> *newSyScalar(
|
||||
const boost::format &name, Scalar x,
|
||||
const bool &write_to_snapshot = false) {
|
||||
return this->newSyScalar(name.str(), x, write_to_snapshot);
|
||||
}
|
||||
|
||||
///@deprecated
|
||||
void mpiSync(MPI_Communication &comm, int root = 0) {
|
||||
namespace ph = std::placeholders;
|
||||
for (StateMap::iterator i = state_map.begin(); i != state_map.end();
|
||||
++i) {
|
||||
i->second->syncData(std::bind(
|
||||
&MPI_Communication::broadcast, comm, ph::_1, ph::_2, ph::_3, root));
|
||||
}
|
||||
}
|
||||
|
||||
void set_save_in_snapshot(const std::string &name, const bool save) {
|
||||
save_map[name] = save;
|
||||
}
|
||||
|
||||
void set_save_in_snapshot(const boost::format &name, const bool save) {
|
||||
set_save_in_snapshot(name.str(), save);
|
||||
}
|
||||
|
||||
bool get_save_in_snapshot(const std::string &name) {
|
||||
SaveMap::const_iterator i = save_map.find(name);
|
||||
if (i == save_map.end()) {
|
||||
error_helper<ErrorBadState>(
|
||||
boost::format("Invalid access to %s") % name);
|
||||
}
|
||||
return i->second;
|
||||
}
|
||||
|
||||
bool get_save_in_snapshot(const boost::format &name) {
|
||||
return get_save_in_snapshot(name.str());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Save the full content of the dictionnary into the indicated HDF5 group.
|
||||
*
|
||||
* @param fg HDF5 group/file to save the state in.
|
||||
*/
|
||||
void saveState(H5_CommonFileGroup &fg) {
|
||||
ConsoleContext<LOG_DEBUG> ctx("saveState");
|
||||
H5::Group g_scalar = fg.createGroup("scalars");
|
||||
for (auto &&i : state_map) {
|
||||
ctx.print("Saving " + i.first);
|
||||
if (i.second->isScalar())
|
||||
i.second->saveTo(g_scalar);
|
||||
else {
|
||||
H5::Group g = fg.createGroup(i.first);
|
||||
i.second->saveTo(g);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Save the full content of the dictionnary into the indicated HDF5 group.
|
||||
* This is the MPI parallel variant.
|
||||
*
|
||||
* @param fg HDF5 group/file to save the state in.
|
||||
*/
|
||||
void mpiSaveState(
|
||||
std::shared_ptr<H5_CommonFileGroup> fg, MPI_Communication *comm,
|
||||
bool reassembly, const bool write_snapshot = false) {
|
||||
ConsoleContext<LOG_VERBOSE> ctx("mpiSaveState");
|
||||
H5::Group g_scalar;
|
||||
boost::optional<H5_CommonFileGroup &> g_scalar_opt;
|
||||
|
||||
if (fg) {
|
||||
g_scalar = fg->createGroup("scalars");
|
||||
g_scalar_opt = g_scalar;
|
||||
}
|
||||
|
||||
for (auto &&i : state_map) {
|
||||
if (write_snapshot && (!get_save_in_snapshot(i.first))) {
|
||||
ctx.print("Skip saving " + i.first);
|
||||
continue;
|
||||
}
|
||||
ctx.print("Saving " + i.first);
|
||||
if (i.second->isScalar())
|
||||
i.second->saveTo(g_scalar_opt, comm, reassembly);
|
||||
else {
|
||||
H5::Group g;
|
||||
boost::optional<H5_CommonFileGroup &> g_opt;
|
||||
if (fg) {
|
||||
g = fg->createGroup(i.first);
|
||||
g_opt = g;
|
||||
}
|
||||
i.second->saveTo(g_opt, comm, reassembly);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void restoreStateWithFailure(H5_CommonFileGroup &fg) {
|
||||
Console &cons = Console::instance();
|
||||
H5::Group g_scalar = fg.openGroup("scalars");
|
||||
for (StateMap::iterator i = state_map.begin(); i != state_map.end();
|
||||
++i) {
|
||||
cons.print<LOG_VERBOSE>("Attempting to restore " + i->first);
|
||||
#if H5_VERSION_GE(1, 10, 1)
|
||||
if (!g_scalar.nameExists(i->first)) {
|
||||
cons.print<LOG_WARNING>("Failure to restore");
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
if (i->second->isScalar())
|
||||
// Partial is only valid for 'scalar' types.
|
||||
i->second->loadFrom(g_scalar, false);
|
||||
else {
|
||||
H5::Group g = fg.openGroup(i->first);
|
||||
i->second->loadFrom(g);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function to launch another function once all indicated requirements have been loaded from the
|
||||
// restart file.
|
||||
void subscribePostRestore(
|
||||
Requirements const &requirements, std::function<void()> f) {
|
||||
if (std::includes(
|
||||
requirements.begin(), requirements.end(), loaded.begin(),
|
||||
loaded.end())) {
|
||||
f();
|
||||
return;
|
||||
}
|
||||
postLoad.push_back(std::make_tuple(requirements, f));
|
||||
}
|
||||
|
||||
void triggerPostRestore(std::string const &n) {
|
||||
loaded.insert(n);
|
||||
auto i = postLoad.begin();
|
||||
while (i != postLoad.end()) {
|
||||
auto const &req = std::get<0>(*i);
|
||||
if (!std::includes(
|
||||
req.begin(), req.end(), loaded.begin(), loaded.end())) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
std::get<1> (*i)();
|
||||
auto j = i;
|
||||
++j;
|
||||
postLoad.erase(i);
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
|
||||
void restoreState(
|
||||
H5_CommonFileGroup &fg, bool partial = false, bool loadSnapshot = false,
|
||||
bool acceptFailure = false) {
|
||||
Console &cons = Console::instance();
|
||||
H5::Group g_scalar = fg.openGroup("scalars");
|
||||
StateMap currentMap = state_map; // Protect against online modifications
|
||||
|
||||
do {
|
||||
for (StateMap::iterator i = currentMap.begin(); i != currentMap.end();
|
||||
++i) {
|
||||
if (loadSnapshot && !get_save_in_snapshot(i->first))
|
||||
continue;
|
||||
|
||||
cons.print<LOG_VERBOSE>("Restoring " + i->first);
|
||||
#if H5_VERSION_GE(1, 10, 1)
|
||||
if (acceptFailure && !g_scalar.nameExists(i->first)) {
|
||||
cons.print<LOG_WARNING>("Failure to restore. Skipping.");
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
if (i->second->isScalar())
|
||||
// Partial is only valid for 'scalar' types.
|
||||
i->second->loadFrom(g_scalar, partial);
|
||||
else {
|
||||
auto g = fg.openGroup(i->first);
|
||||
i->second->loadFrom(g);
|
||||
}
|
||||
triggerPostRestore(i->first);
|
||||
}
|
||||
currentMap = toProcess;
|
||||
toProcess.clear();
|
||||
} while (currentMap.size() > 0);
|
||||
|
||||
// Clear up all pending
|
||||
if (postLoad.size() > 0) {
|
||||
cons.print<LOG_ERROR>("Some post-restore triggers were not executed.");
|
||||
MPI_Communication::instance()->abort();
|
||||
}
|
||||
loaded.clear();
|
||||
postLoad.clear();
|
||||
}
|
||||
};
|
||||
|
||||
/** @example example_markov_state.cpp
|
||||
* This is an example of how to use the MarkovState class.
|
||||
*/
|
||||
|
||||
}; // namespace LibLSS
|
||||
|
||||
#endif
|
19
libLSS/mcmc/state_element.cpp
Normal file
19
libLSS/mcmc/state_element.cpp
Normal file
|
@ -0,0 +1,19 @@
|
|||
/*+
|
||||
ARES/HADES/BORG Package -- ./libLSS/mcmc/state_element.cpp
|
||||
Copyright (C) 2014-2020 Guilhem Lavaux <guilhem.lavaux@iap.fr>
|
||||
Copyright (C) 2009-2020 Jens Jasche <jens.jasche@fysik.su.se>
|
||||
|
||||
Additional contributions from:
|
||||
Guilhem Lavaux <guilhem.lavaux@iap.fr> (2023)
|
||||
|
||||
+*/
|
||||
#include <H5Cpp.h>
|
||||
#include <CosmoTool/hdf5_array.hpp>
|
||||
#include "state_element.hpp"
|
||||
|
||||
using namespace LibLSS;
|
||||
|
||||
StateElement::~StateElement()
|
||||
{
|
||||
}
|
||||
|
609
libLSS/mcmc/state_element.hpp
Normal file
609
libLSS/mcmc/state_element.hpp
Normal file
|
@ -0,0 +1,609 @@
|
|||
/*+
|
||||
ARES/HADES/BORG Package -- ./libLSS/mcmc/state_element.hpp
|
||||
Copyright (C) 2014-2020 Guilhem Lavaux <guilhem.lavaux@iap.fr>
|
||||
Copyright (C) 2009-2020 Jens Jasche <jens.jasche@fysik.su.se>
|
||||
|
||||
Additional contributions from:
|
||||
Guilhem Lavaux <guilhem.lavaux@iap.fr> (2023)
|
||||
|
||||
+*/
|
||||
#ifndef _LIBLSS_STATE_ELT_HPP
|
||||
#define _LIBLSS_STATE_ELT_HPP
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include "libLSS/tools/align_helper.hpp"
|
||||
#include "libLSS/mpi/generic_mpi.hpp"
|
||||
#include <boost/function.hpp>
|
||||
#include <boost/multi_array.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <boost/lexical_cast.hpp>
|
||||
#include <string>
|
||||
#include <H5Cpp.h>
|
||||
#include <iostream>
|
||||
#include <CosmoTool/hdf5_array.hpp>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include "libLSS/tools/errors.hpp"
|
||||
#include "libLSS/tools/memusage.hpp"
|
||||
#include "libLSS/tools/hdf5_type.hpp"
|
||||
#include "libLSS/tools/defer.hpp"
|
||||
|
||||
namespace LibLSS {
|
||||
|
||||
/**
|
||||
* @brief Generic Markov Chain State element
|
||||
* This is the base class for other more strange elements
|
||||
*
|
||||
*/
|
||||
class StateElement {
|
||||
protected:
|
||||
std::string name;
|
||||
typedef std::function<void(void *, int, MPI_Datatype)> SyncFunction;
|
||||
typedef std::function<void()> NotifyFunction;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Construct a new State Element object
|
||||
*
|
||||
*/
|
||||
StateElement() : name("_unknown_") {}
|
||||
|
||||
friend class MarkovState;
|
||||
void checkName() {
|
||||
if (name == "_unknown_") {
|
||||
std::cerr << "Name of a state element is undefined" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Defer deferLoad, deferInit;
|
||||
|
||||
/**
|
||||
* @brief Destroy the State Element object
|
||||
*
|
||||
*/
|
||||
virtual ~StateElement();
|
||||
|
||||
/**
|
||||
* @brief Register a functor get notifications when this element is finished being loaded.
|
||||
* @deprecated Use deferLoad directly
|
||||
*
|
||||
* @param f the functor, must support copy-constructible.
|
||||
* @sa loaded
|
||||
*/
|
||||
void subscribeLoaded(NotifyFunction f) { deferLoad.ready(f); }
|
||||
|
||||
/**
|
||||
* @brief Send a message that the element has been loaded.
|
||||
* @deprecated Use deferLoad directly.
|
||||
* @sa subscribeLoaded
|
||||
*/
|
||||
void loaded() { deferLoad.submit_ready(); }
|
||||
|
||||
/**
|
||||
* @brief Get the name of this state element. This is used to store it in file.
|
||||
*
|
||||
* @return const std::string&
|
||||
*/
|
||||
const std::string &getName() const { return name; }
|
||||
|
||||
/**
|
||||
* @brief Check if this element is a scalar.
|
||||
*
|
||||
* @return true if it is a scalar, i.e. trivially serializable
|
||||
* @return false it it is not, requires a lot more operations to (de)serialize.
|
||||
*/
|
||||
virtual bool isScalar() const { return false; }
|
||||
|
||||
bool updated() { return false; }
|
||||
|
||||
/**
|
||||
* @brief Save the element to an HDF5 group, only one core is using the file.
|
||||
*
|
||||
*/
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) = 0;
|
||||
|
||||
/**
|
||||
* @brief Save the element to an HDF5 group.
|
||||
*
|
||||
* @param fg an HDF5 group/file
|
||||
* @param comm an MPI communicator
|
||||
* @param partialSave whether only the partial save is requested (i.e. generate restart file).
|
||||
*/
|
||||
virtual void saveTo(
|
||||
H5_CommonFileGroup &fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
boost::optional<H5_CommonFileGroup &> o_fg = fg;
|
||||
saveTo(o_fg, comm, partialSave);
|
||||
}
|
||||
|
||||
virtual void saveTo2(
|
||||
std::shared_ptr<H5_CommonFileGroup> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
boost::optional<H5_CommonFileGroup &> o_fg;
|
||||
if (fg)
|
||||
o_fg = *fg;
|
||||
saveTo(o_fg, comm, partialSave);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param fg
|
||||
* @param partialLoad
|
||||
*/
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = true) = 0;
|
||||
virtual void syncData(SyncFunction f) = 0;
|
||||
};
|
||||
|
||||
/* Generic array template class for Markov Chain state element. It supports all scalars
|
||||
* and complex derived types.
|
||||
*/
|
||||
template <class AType, bool NeedReassembly = false>
|
||||
class GenericArrayStateElement : public StateElement {
|
||||
public:
|
||||
enum { Reassembly = NeedReassembly };
|
||||
typedef AType ArrayType;
|
||||
typedef typename ArrayType::element element;
|
||||
typedef typename ArrayType::index_gen index_gen;
|
||||
std::vector<hsize_t> real_dims;
|
||||
std::shared_ptr<ArrayType> array;
|
||||
bool realDimSet;
|
||||
bool resetOnSave;
|
||||
element reset_value;
|
||||
bool auto_resize;
|
||||
|
||||
bool requireReassembly() const { return (bool)Reassembly == true; }
|
||||
void setResetOnSave(const element &_reset_value) {
|
||||
this->reset_value = _reset_value;
|
||||
resetOnSave = true;
|
||||
}
|
||||
void setAutoResize(bool do_resize) { auto_resize = do_resize; }
|
||||
|
||||
template <typename ExtentDim>
|
||||
void setRealDims(const ExtentDim &d) {
|
||||
Console::instance().c_assert(
|
||||
d.size() == real_dims.size(), "Invalid dimension size");
|
||||
std::copy(d.begin(), d.end(), real_dims.begin());
|
||||
realDimSet = true;
|
||||
}
|
||||
|
||||
GenericArrayStateElement()
|
||||
: StateElement(), real_dims(ArrayType::dimensionality),
|
||||
realDimSet(false), resetOnSave(false), auto_resize(false) {}
|
||||
virtual ~GenericArrayStateElement() {}
|
||||
|
||||
virtual bool isScalar() const { return true; }
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
checkName();
|
||||
try {
|
||||
if (!requireReassembly() || partialSave) {
|
||||
ConsoleContext<LOG_DEBUG> ctx("saveTo(): saving variable " + name);
|
||||
if (partialSave || (comm != 0 && comm->rank() == 0)) {
|
||||
ctx.print("partialSave or rank==0");
|
||||
if (!fg) {
|
||||
error_helper<ErrorBadState>(
|
||||
"saveTo() requires a valid HDF5 handle on this core.");
|
||||
}
|
||||
CosmoTool::hdf5_write_array(*fg, name, *array);
|
||||
} else {
|
||||
ctx.print("Non-root rank and not partial save. Just passthrough.");
|
||||
}
|
||||
} else {
|
||||
CosmoTool::get_hdf5_data_type<element> HT;
|
||||
Console::instance().c_assert(
|
||||
comm != 0, "Array need reassembly and no communicator given");
|
||||
Console::instance().c_assert(
|
||||
realDimSet,
|
||||
"Real dimensions of the array over communicator is not set for " +
|
||||
this->getName());
|
||||
std::vector<hsize_t> remote_bases(ArrayType::dimensionality);
|
||||
std::vector<hsize_t> remote_dims(ArrayType::dimensionality);
|
||||
MPI_Datatype dt = translateMPIType<hsize_t>();
|
||||
MPI_Datatype et = translateMPIType<element>();
|
||||
|
||||
ConsoleContext<LOG_DEBUG> ctx("reassembling of variable " + name);
|
||||
|
||||
if (comm->rank() == 0) {
|
||||
if (!fg)
|
||||
error_helper<ErrorBadState>(
|
||||
"saveTo() requires a valid HDF5 handle on this core.");
|
||||
|
||||
ctx.print("Writing rank 0 data first. Dimensions = ");
|
||||
for (size_t n = 0; n < real_dims.size(); n++)
|
||||
ctx.print(boost::lexical_cast<std::string>(real_dims[n]));
|
||||
CosmoTool::hdf5_write_array(
|
||||
*fg, name, *array, HT.type(), real_dims, true, true);
|
||||
|
||||
ctx.print("Grabbing other rank data");
|
||||
for (int r = 1; r < comm->size(); r++) {
|
||||
ArrayType a;
|
||||
|
||||
ctx.print(boost::format("Incoming data from rank %d") % r);
|
||||
comm->recv(
|
||||
remote_dims.data(), ArrayType::dimensionality, dt, r, 0);
|
||||
comm->recv(
|
||||
remote_bases.data(), ArrayType::dimensionality, dt, r, 1);
|
||||
a.resize(
|
||||
CosmoTool::hdf5_extent_gen<ArrayType::dimensionality>::build(
|
||||
remote_dims.data()));
|
||||
a.reindex(remote_bases);
|
||||
comm->recv(a.data(), a.num_elements(), et, r, 2);
|
||||
CosmoTool::hdf5_write_array(
|
||||
*fg, name, a, HT.type(), real_dims, false, true);
|
||||
}
|
||||
} else {
|
||||
ctx.print("Sending data");
|
||||
comm->send(array->shape(), ArrayType::dimensionality, dt, 0, 0);
|
||||
comm->send(
|
||||
array->index_bases(), ArrayType::dimensionality, dt, 0, 1);
|
||||
comm->send(array->data(), array->num_elements(), et, 0, 2);
|
||||
}
|
||||
}
|
||||
if (resetOnSave)
|
||||
fill(reset_value);
|
||||
} catch (const H5::Exception &e) {
|
||||
error_helper<ErrorIO>(e.getDetailMsg());
|
||||
}
|
||||
}
|
||||
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = false) {
|
||||
checkName();
|
||||
try {
|
||||
if (!requireReassembly() || !partialLoad) {
|
||||
ConsoleContext<LOG_DEBUG> ctx("loadFrom full");
|
||||
ctx.print(
|
||||
boost::format("loadFrom(reassembly=%d,partialLoad=%d,autoresize=%"
|
||||
"d): loading variable %s") %
|
||||
requireReassembly() % partialLoad % auto_resize % name);
|
||||
ctx.print("partialSave or rank==0");
|
||||
CosmoTool::hdf5_read_array(fg, name, *array, auto_resize);
|
||||
} else {
|
||||
Console::instance().c_assert(
|
||||
realDimSet,
|
||||
"Real dimensions of the array over communicator is not set for " +
|
||||
this->getName());
|
||||
std::vector<hsize_t> remote_bases(ArrayType::dimensionality);
|
||||
std::vector<hsize_t> remote_dims(ArrayType::dimensionality);
|
||||
|
||||
ConsoleContext<LOG_DEBUG> ctx("dissassembling of variable " + name);
|
||||
CosmoTool::hdf5_read_array(fg, name, *array, false, true);
|
||||
}
|
||||
} catch (const CosmoTool::InvalidDimensions &) {
|
||||
error_helper<ErrorBadState>(
|
||||
boost::format("Incompatible array size loading '%s'") % getName());
|
||||
} catch (const H5::GroupIException &) {
|
||||
error_helper<ErrorIO>(
|
||||
"Could not open variable " + getName() + " in state file");
|
||||
} catch (const H5::DataSetIException &error) {
|
||||
error_helper<ErrorIO>(
|
||||
"Could not open variable " + getName() + " in state file");
|
||||
}
|
||||
loaded();
|
||||
}
|
||||
|
||||
virtual void syncData(SyncFunction f) {
|
||||
typename ArrayType::size_type S;
|
||||
f(array->data(), array->num_elements(),
|
||||
translateMPIType<typename AType::element>());
|
||||
}
|
||||
|
||||
virtual void fill(const element &v) {
|
||||
//#pragma omp simd
|
||||
#pragma omp parallel for
|
||||
for (size_t i = 0; i < array->num_elements(); i++)
|
||||
array->data()[i] = v;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T, std::size_t DIMENSIONS,
|
||||
typename Allocator = LibLSS::track_allocator<T>,
|
||||
bool NeedReassembly = false>
|
||||
class ArrayStateElement
|
||||
: public GenericArrayStateElement<
|
||||
boost::multi_array<T, DIMENSIONS, Allocator>, NeedReassembly> {
|
||||
typedef GenericArrayStateElement<
|
||||
boost::multi_array<T, DIMENSIONS, Allocator>, NeedReassembly>
|
||||
super_type;
|
||||
|
||||
public:
|
||||
typedef typename super_type::ArrayType ArrayType;
|
||||
typedef typename boost::multi_array_ref<T, DIMENSIONS> RefArrayType;
|
||||
typedef typename super_type::index_gen index_gen;
|
||||
|
||||
enum { AlignState = DetectAlignment<Allocator>::Align };
|
||||
typedef Eigen::Array<T, Eigen::Dynamic, 1> E_Array;
|
||||
typedef Eigen::Map<E_Array, AlignState> MapArray;
|
||||
|
||||
template <typename ExtentList>
|
||||
ArrayStateElement(
|
||||
const ExtentList &extents, const Allocator &allocator = Allocator(),
|
||||
const boost::general_storage_order<DIMENSIONS> &ordering =
|
||||
boost::c_storage_order())
|
||||
: super_type() {
|
||||
this->array = std::make_shared<ArrayType>(extents, ordering, allocator);
|
||||
Console::instance().print<LOG_DEBUG>(
|
||||
std::string("Creating array which is ") +
|
||||
((((int)AlignState == (int)Eigen::Aligned) ? "ALIGNED"
|
||||
: "UNALIGNED")));
|
||||
}
|
||||
|
||||
MapArray eigen() {
|
||||
return MapArray(this->array->data(), this->array->num_elements());
|
||||
}
|
||||
|
||||
virtual void fill(const typename super_type::element &v) {
|
||||
eigen().fill(v);
|
||||
}
|
||||
|
||||
// This is unsafe. Use it with precaution
|
||||
void unsafeSetName(const std::string &n) { this->name = n; }
|
||||
};
|
||||
|
||||
template <typename T, std::size_t DIMENSIONS>
|
||||
class RefArrayStateElement
|
||||
: public GenericArrayStateElement<boost::multi_array_ref<T, DIMENSIONS>> {
|
||||
public:
|
||||
typedef boost::multi_array_ref<T, DIMENSIONS> ArrayType;
|
||||
typedef boost::multi_array_ref<T, DIMENSIONS> RefArrayType;
|
||||
|
||||
template <typename ExtentList>
|
||||
RefArrayStateElement(
|
||||
T *data, const ExtentList &extents,
|
||||
const boost::general_storage_order<DIMENSIONS> &ordering =
|
||||
boost::c_storage_order())
|
||||
: StateElement() {
|
||||
this->array = std::make_shared<ArrayType>(data, extents);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct _scalar_writer {
|
||||
template <typename DT>
|
||||
static inline void write(H5::DataSet &dataset, U &v, DT dt) {
|
||||
dataset.write(&v, dt.type());
|
||||
}
|
||||
|
||||
template <typename DT>
|
||||
static inline void read(H5::DataSet &dataset, U &v, DT dt) {
|
||||
dataset.read(&v, dt.type());
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _scalar_writer<std::string> {
|
||||
template <typename DT>
|
||||
static inline void write(H5::DataSet &dataset, std::string &v, DT dt) {
|
||||
dataset.write(v, dt.type());
|
||||
}
|
||||
|
||||
template <typename DT>
|
||||
static inline void read(H5::DataSet &dataset, std::string &v, DT dt) {
|
||||
dataset.read(v, dt.type());
|
||||
}
|
||||
};
|
||||
|
||||
/* Generic scalar Markov State element. */
|
||||
template <typename T>
|
||||
class ScalarStateElement : public StateElement {
|
||||
public:
|
||||
T value;
|
||||
T reset_value;
|
||||
bool resetOnSave;
|
||||
bool doNotRestore;
|
||||
|
||||
ScalarStateElement()
|
||||
: StateElement(), value(), reset_value(), resetOnSave(false),
|
||||
doNotRestore(false) {}
|
||||
virtual ~ScalarStateElement() {}
|
||||
|
||||
void setDoNotRestore(bool doNotRestore) {
|
||||
this->doNotRestore = doNotRestore;
|
||||
}
|
||||
void setResetOnSave(const T &_reset_value) {
|
||||
this->reset_value = _reset_value;
|
||||
resetOnSave = true;
|
||||
}
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
CosmoTool::get_hdf5_data_type<T> hdf_data_type;
|
||||
std::vector<hsize_t> dimensions(1);
|
||||
dimensions[0] = 1;
|
||||
|
||||
if (partialSave || (comm != 0 && comm->rank() == 0)) {
|
||||
checkName();
|
||||
H5::DataSpace dataspace(1, dimensions.data());
|
||||
H5::DataSet dataset =
|
||||
(*fg).createDataSet(name, hdf_data_type.type(), dataspace);
|
||||
|
||||
_scalar_writer<T>::write(dataset, value, hdf_data_type);
|
||||
if (resetOnSave)
|
||||
value = reset_value;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = true) {
|
||||
CosmoTool::get_hdf5_data_type<T> hdf_data_type;
|
||||
std::vector<hsize_t> dimensions(1);
|
||||
H5::DataSet dataset;
|
||||
|
||||
if (doNotRestore) {
|
||||
return;
|
||||
}
|
||||
|
||||
dimensions[0] = 1;
|
||||
|
||||
checkName();
|
||||
try {
|
||||
dataset = fg.openDataSet(name);
|
||||
} catch (const H5::GroupIException &) {
|
||||
error_helper<ErrorIO>(
|
||||
"Could not find variable " + name + " in state file.");
|
||||
}
|
||||
H5::DataSpace dataspace = dataset.getSpace();
|
||||
hsize_t n;
|
||||
|
||||
if (dataspace.getSimpleExtentNdims() != 1)
|
||||
error_helper<ErrorIO>("Invalid stored dimension for " + getName());
|
||||
|
||||
dataspace.getSimpleExtentDims(&n);
|
||||
if (n != 1)
|
||||
error_helper<ErrorIO>("Invalid stored dimension for " + getName());
|
||||
|
||||
_scalar_writer<T>::read(dataset, value, hdf_data_type);
|
||||
loaded();
|
||||
}
|
||||
|
||||
operator T() { return value; }
|
||||
|
||||
virtual bool isScalar() const { return true; }
|
||||
|
||||
virtual void syncData(SyncFunction f) {
|
||||
error_helper<ErrorBadState>(
|
||||
"MPI synchronization not supported by this type");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SyncableScalarStateElement : public ScalarStateElement<T> {
|
||||
public:
|
||||
typedef typename ScalarStateElement<T>::SyncFunction SyncFunction;
|
||||
|
||||
virtual void syncData(SyncFunction f) {
|
||||
f(&this->value, 1, translateMPIType<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SharedObjectStateElement : public StateElement {
|
||||
public:
|
||||
std::shared_ptr<T> obj;
|
||||
|
||||
SharedObjectStateElement() : StateElement() {}
|
||||
SharedObjectStateElement(std::shared_ptr<T> &src)
|
||||
: StateElement(), obj(src) {}
|
||||
SharedObjectStateElement(std::shared_ptr<T> &&src)
|
||||
: StateElement(), obj(src) {}
|
||||
virtual ~SharedObjectStateElement() {}
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<CosmoTool::H5_CommonFileGroup &> fg,
|
||||
MPI_Communication *comm = 0, bool partialSave = true) {
|
||||
if (fg)
|
||||
obj->save(*fg);
|
||||
}
|
||||
|
||||
virtual void
|
||||
loadFrom(CosmoTool::H5_CommonFileGroup &fg, bool partialSave = true) {
|
||||
obj->restore(fg);
|
||||
loaded();
|
||||
}
|
||||
|
||||
operator T &() { return *obj; }
|
||||
|
||||
T &get() { return *obj; }
|
||||
const T &get() const { return *obj; }
|
||||
|
||||
virtual void syncData(SyncFunction f) {}
|
||||
};
|
||||
|
||||
template <typename T, bool autofree>
|
||||
class ObjectStateElement : public StateElement {
|
||||
public:
|
||||
T *obj;
|
||||
|
||||
ObjectStateElement() : StateElement() {}
|
||||
ObjectStateElement(T *o) : StateElement(), obj(o) {}
|
||||
virtual ~ObjectStateElement() {
|
||||
if (autofree)
|
||||
delete obj;
|
||||
}
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
if (fg)
|
||||
obj->save(*fg);
|
||||
}
|
||||
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialSave = true) {
|
||||
obj->restore(fg);
|
||||
loaded();
|
||||
}
|
||||
|
||||
operator T &() { return *obj; }
|
||||
|
||||
T &get() { return *obj; }
|
||||
const T &get() const { return *obj; }
|
||||
|
||||
virtual void syncData(SyncFunction f) {}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class TemporaryElement : public StateElement {
|
||||
protected:
|
||||
T obj;
|
||||
|
||||
public:
|
||||
TemporaryElement(T const &a) : obj(a) {}
|
||||
TemporaryElement(T &&a) : obj(a) {}
|
||||
|
||||
operator T &() { return obj; }
|
||||
|
||||
T &get() { return obj; }
|
||||
const T &get() const { return obj; }
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {}
|
||||
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialSave = true) {}
|
||||
|
||||
virtual void syncData(SyncFunction f) {}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class RandomStateElement : public StateElement {
|
||||
protected:
|
||||
std::shared_ptr<T> rng;
|
||||
|
||||
public:
|
||||
RandomStateElement(T *generator, bool handover_ = false) {
|
||||
if (handover_) {
|
||||
rng = std::shared_ptr<T>(generator, [](T *a) { delete a; });
|
||||
} else {
|
||||
rng = std::shared_ptr<T>(generator, [](T *a) {});
|
||||
}
|
||||
}
|
||||
RandomStateElement(std::shared_ptr<T> generator) : rng(generator) {}
|
||||
virtual ~RandomStateElement() {}
|
||||
|
||||
const T &get() const { return *rng; }
|
||||
T &get() { return *rng; }
|
||||
|
||||
virtual void saveTo(
|
||||
boost::optional<H5_CommonFileGroup &> fg, MPI_Communication *comm = 0,
|
||||
bool partialSave = true) {
|
||||
if (fg)
|
||||
rng->save(*fg);
|
||||
}
|
||||
|
||||
virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = false) {
|
||||
rng->restore(fg, partialLoad);
|
||||
loaded();
|
||||
}
|
||||
|
||||
virtual void syncData(SyncFunction f) {}
|
||||
};
|
||||
}; // namespace LibLSS
|
||||
|
||||
#endif
|
88
libLSS/mcmc/state_sync.hpp
Normal file
88
libLSS/mcmc/state_sync.hpp
Normal file
|
@ -0,0 +1,88 @@
|
|||
/*+
|
||||
ARES/HADES/BORG Package -- ./libLSS/mcmc/state_sync.hpp
|
||||
Copyright (C) 2014-2020 Guilhem Lavaux <guilhem.lavaux@iap.fr>
|
||||
Copyright (C) 2009-2020 Jens Jasche <jens.jasche@fysik.su.se>
|
||||
|
||||
Additional contributions from:
|
||||
Guilhem Lavaux <guilhem.lavaux@iap.fr> (2023)
|
||||
|
||||
+*/
|
||||
#ifndef __LIBLSS_STATE_ELEMENT_SYNC_HPP
|
||||
#define __LIBLSS_STATE_ELEMENT_SYNC_HPP
|
||||
|
||||
#include <functional>
|
||||
#include "libLSS/tools/console.hpp"
|
||||
#include "libLSS/mpi/generic_mpi.hpp"
|
||||
|
||||
namespace LibLSS {
|
||||
|
||||
class StateElement;
|
||||
|
||||
/**
|
||||
* Helper class to synchronize many StateElement variable at the same time with MPI.
|
||||
* @deprecated
|
||||
*/
|
||||
class MPI_SyncBundle {
|
||||
protected:
|
||||
typedef std::list<StateElement *> List;
|
||||
|
||||
List list;
|
||||
public:
|
||||
/// Constructor
|
||||
MPI_SyncBundle() {}
|
||||
~MPI_SyncBundle() {}
|
||||
|
||||
/**
|
||||
* Add a specified element to the bundle.
|
||||
* @param e the element to be added
|
||||
*/
|
||||
MPI_SyncBundle& operator+=(StateElement *e) {
|
||||
list.push_back(e);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the provided synchronization function on all elements
|
||||
* of the bundle.
|
||||
* @param f the Functor to be executed.
|
||||
*/
|
||||
template<typename Function>
|
||||
void syncData(Function f) {
|
||||
ConsoleContext<LOG_DEBUG> ctx("sync bundle");
|
||||
for (List::iterator i = list.begin(); i != list.end(); ++i)
|
||||
(*i)->syncData(f);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a broadcast operation on the bundle.
|
||||
* @param comm the MPI communicator.
|
||||
* @param root the root for the broadcast operation (default is 0).
|
||||
*/
|
||||
void mpiBroadcast(MPI_Communication& comm, int root = 0) {
|
||||
namespace ph = std::placeholders;
|
||||
syncData(std::bind(&MPI_Communication::broadcast, comm, ph::_1, ph::_2, ph::_3, root));
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a all reduce (max) operation on the bundle.
|
||||
* @param comm the MPI communicator.
|
||||
*/
|
||||
void mpiAllMax(MPI_Communication& comm) {
|
||||
namespace ph = std::placeholders;
|
||||
syncData(std::bind(&MPI_Communication::all_reduce, comm, MPI_IN_PLACE, ph::_1, ph::_2, ph::_3, MPI_MAX));
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a all reduce (sum) operation on the bundle.
|
||||
* @param comm the MPI communicator.
|
||||
*/
|
||||
void mpiAllSum(MPI_Communication& comm) {
|
||||
namespace ph = std::placeholders;
|
||||
syncData(std::bind(&MPI_Communication::all_reduce, comm, MPI_IN_PLACE, ph::_1, ph::_2, ph::_3, MPI_SUM));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
#endif
|
Loading…
Add table
Add a link
Reference in a new issue