/*+
    ARES/HADES/BORG Package -- ./libLSS/mcmc/state_element.hpp
    Copyright (C) 2014-2020 Guilhem Lavaux
    Copyright (C) 2009-2020 Jens Jasche

    Additional contributions from:
       Guilhem Lavaux (2023)

+*/
#ifndef _LIBLSS_STATE_ELT_HPP
#define _LIBLSS_STATE_ELT_HPP

// NOTE(review): in this copy of the file every bracketed "<...>" span has been
// lost: the system header names on the bare #include lines below, and the
// template arguments of std::function, boost::optional and std::shared_ptr
// further down. The code tokens are kept exactly as found; restore them from
// the upstream file before building.
#include
#include "libLSS/tools/align_helper.hpp"
#include "libLSS/mpi/generic_mpi.hpp"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "libLSS/tools/errors.hpp"
#include "libLSS/tools/memusage.hpp"
#include "libLSS/tools/hdf5_type.hpp"
#include "libLSS/tools/defer.hpp"

namespace LibLSS {

  /**
   * @brief Generic Markov Chain State element.
   *
   * Abstract base class of every element defined in this file. It stores the
   * element name, exposes two Defer objects (deferLoad, deferInit) and
   * declares the save / load / synchronize interface that concrete elements
   * implement.
   */
  class StateElement {
  protected:
    /// Name of the element; "_unknown_" until something assigns a real one.
    std::string name;

    // Callback types used by syncData() and subscribeLoaded().
    // NOTE(review): the function signatures (template arguments) are missing
    // from this text. From the call sites in this file, SyncFunction is
    // invoked with (data pointer, element count, MPI datatype) and
    // NotifyFunction is handed to Defer::ready() -- confirm against upstream.
    typedef std::function SyncFunction;
    typedef std::function NotifyFunction;

  protected:
    /**
     * @brief Construct a new State Element object.
     *
     * The name is set to the placeholder "_unknown_". The constructor is
     * protected, so only derived classes and the befriended MarkovState can
     * call it.
     */
    StateElement() : name("_unknown_") {}

    // MarkovState gets access to the protected members (including `name`).
    friend class MarkovState;

    /**
     * @brief Abort the program if the element still carries the placeholder
     * name.
     *
     * Prints a message on std::cerr and calls abort(); it does not throw.
     */
    void checkName() {
      if (name == "_unknown_") {
        std::cerr << "Name of a state element is undefined" << std::endl;
        abort();
      }
    }

  public:
    /// Deferred-notification objects. deferLoad is the one driven by
    /// subscribeLoaded() / loaded(); deferInit is not used in this class.
    Defer deferLoad, deferInit;

    /**
     * @brief Destroy the State Element object.
     *
     * Declared here only; the definition is not part of this header text.
     */
    virtual ~StateElement();

    /**
     * @brief Register a functor to get notified when this element has
     * finished being loaded.
     * @deprecated Use deferLoad directly.
     *
     * @param f the functor; it is passed by value to deferLoad.ready().
     * @sa loaded
     */
    void subscribeLoaded(NotifyFunction f) { deferLoad.ready(f); }

    /**
     * @brief Signal that the element has been loaded.
     * @deprecated Use deferLoad directly.
     *
     * Forwards to deferLoad.submit_ready().
     * @sa subscribeLoaded
     */
    void loaded() { deferLoad.submit_ready(); }

    /**
     * @brief Get the name of this state element. This is used to store it in
     * file.
     *
     * @return const std::string& reference to the stored name.
     */
    const std::string &getName() const { return name; }

    /**
     * @brief Check if this element is a scalar.
     *
     * @return true if it is a scalar, i.e. trivially serializable
     * @return false if it is not, requires a lot more operations to
     *         (de)serialize.
     *
     * The base implementation returns false.
     */
    virtual bool isScalar() const { return false; }

    /// Always returns false. Not virtual, so derived classes cannot override
    /// it through a base reference.
    bool updated() { return false; }

    /**
     * @brief Save the element to an HDF5 group, only one core is using the
     * file.
     *
     * @param fg optional HDF5 group/file handle; may be empty.
     * @param comm an MPI communicator, may be null.
     * @param partialSave whether only the partial save is requested.
     *
     * NOTE(review): default arguments of virtual functions are taken from the
     * static type of the call expression, not from the override that runs.
     */
    virtual void saveTo(
        boost::optional fg, MPI_Communication *comm = 0,
        bool partialSave = true) = 0;

    /**
     * @brief Save the element to an HDF5 group.
     *
     * Wraps the reference into an optional and forwards to the optional-based
     * overload above.
     *
     * @param fg an HDF5 group/file
     * @param comm an MPI communicator
     * @param partialSave whether only the partial save is requested (i.e.
     *        generate restart file).
     */
    virtual void saveTo(
        H5_CommonFileGroup &fg, MPI_Communication *comm = 0,
        bool partialSave = true) {
      boost::optional o_fg = fg;
      saveTo(o_fg, comm, partialSave);
    }

    /**
     * @brief Save the element through a shared handle that may be null.
     *
     * A null @p fg results in an empty optional being forwarded; otherwise
     * the pointee is forwarded.
     *
     * @param fg shared pointer to an HDF5 group/file, or null.
     * @param comm an MPI communicator
     * @param partialSave forwarded unchanged.
     */
    virtual void saveTo2(
        std::shared_ptr fg, MPI_Communication *comm = 0,
        bool partialSave = true) {
      boost::optional o_fg;
      if (fg)
        o_fg = *fg;
      saveTo(o_fg, comm, partialSave);
    }

    /**
     * @brief Load the element from an HDF5 group.
     *
     * @param fg an HDF5 group/file to read from.
     * @param partialLoad defaults to true here.
     *
     * NOTE(review): some overrides in this file (GenericArrayStateElement,
     * RandomStateElement) declare the default as false, so the value used
     * when the argument is omitted depends on the static type of the caller's
     * reference.
     */
    virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = true) = 0;

    /**
     * @brief Hand the element's data to a synchronization functor.
     *
     * @param f functor invoked by the concrete element with its data.
     */
    virtual void syncData(SyncFunction f) = 0;
  };

  /* Generic array template class for Markov Chain state element. It supports all scalars
   * and complex derived types.
*/ template class GenericArrayStateElement : public StateElement { public: enum { Reassembly = NeedReassembly }; typedef AType ArrayType; typedef typename ArrayType::element element; typedef typename ArrayType::index_gen index_gen; std::vector real_dims; std::shared_ptr array; bool realDimSet; bool resetOnSave; element reset_value; bool auto_resize; bool requireReassembly() const { return (bool)Reassembly == true; } void setResetOnSave(const element &_reset_value) { this->reset_value = _reset_value; resetOnSave = true; } void setAutoResize(bool do_resize) { auto_resize = do_resize; } template void setRealDims(const ExtentDim &d) { Console::instance().c_assert( d.size() == real_dims.size(), "Invalid dimension size"); std::copy(d.begin(), d.end(), real_dims.begin()); realDimSet = true; } GenericArrayStateElement() : StateElement(), real_dims(ArrayType::dimensionality), realDimSet(false), resetOnSave(false), auto_resize(false) {} virtual ~GenericArrayStateElement() {} virtual bool isScalar() const { return true; } virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) { checkName(); try { if (!requireReassembly() || partialSave) { ConsoleContext ctx("saveTo(): saving variable " + name); if (partialSave || (comm != 0 && comm->rank() == 0)) { ctx.print("partialSave or rank==0"); if (!fg) { error_helper( "saveTo() requires a valid HDF5 handle on this core."); } CosmoTool::hdf5_write_array(*fg, name, *array); } else { ctx.print("Non-root rank and not partial save. 
Just passthrough."); } } else { CosmoTool::get_hdf5_data_type HT; Console::instance().c_assert( comm != 0, "Array need reassembly and no communicator given"); Console::instance().c_assert( realDimSet, "Real dimensions of the array over communicator is not set for " + this->getName()); std::vector remote_bases(ArrayType::dimensionality); std::vector remote_dims(ArrayType::dimensionality); MPI_Datatype dt = translateMPIType(); MPI_Datatype et = translateMPIType(); ConsoleContext ctx("reassembling of variable " + name); if (comm->rank() == 0) { if (!fg) error_helper( "saveTo() requires a valid HDF5 handle on this core."); ctx.print("Writing rank 0 data first. Dimensions = "); for (size_t n = 0; n < real_dims.size(); n++) ctx.print(boost::lexical_cast(real_dims[n])); CosmoTool::hdf5_write_array( *fg, name, *array, HT.type(), real_dims, true, true); ctx.print("Grabbing other rank data"); for (int r = 1; r < comm->size(); r++) { ArrayType a; ctx.print(boost::format("Incoming data from rank %d") % r); comm->recv( remote_dims.data(), ArrayType::dimensionality, dt, r, 0); comm->recv( remote_bases.data(), ArrayType::dimensionality, dt, r, 1); a.resize( CosmoTool::hdf5_extent_gen::build( remote_dims.data())); a.reindex(remote_bases); comm->recv(a.data(), a.num_elements(), et, r, 2); CosmoTool::hdf5_write_array( *fg, name, a, HT.type(), real_dims, false, true); } } else { ctx.print("Sending data"); comm->send(array->shape(), ArrayType::dimensionality, dt, 0, 0); comm->send( array->index_bases(), ArrayType::dimensionality, dt, 0, 1); comm->send(array->data(), array->num_elements(), et, 0, 2); } } if (resetOnSave) fill(reset_value); } catch (const H5::Exception &e) { error_helper(e.getDetailMsg()); } } virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = false) { checkName(); try { if (!requireReassembly() || !partialLoad) { ConsoleContext ctx("loadFrom full"); ctx.print( boost::format("loadFrom(reassembly=%d,partialLoad=%d,autoresize=%" "d): loading variable %s") % 
requireReassembly() % partialLoad % auto_resize % name); ctx.print("partialSave or rank==0"); CosmoTool::hdf5_read_array(fg, name, *array, auto_resize); } else { Console::instance().c_assert( realDimSet, "Real dimensions of the array over communicator is not set for " + this->getName()); std::vector remote_bases(ArrayType::dimensionality); std::vector remote_dims(ArrayType::dimensionality); ConsoleContext ctx("dissassembling of variable " + name); CosmoTool::hdf5_read_array(fg, name, *array, false, true); } } catch (const CosmoTool::InvalidDimensions &) { error_helper( boost::format("Incompatible array size loading '%s'") % getName()); } catch (const H5::GroupIException &) { error_helper( "Could not open variable " + getName() + " in state file"); } catch (const H5::DataSetIException &error) { error_helper( "Could not open variable " + getName() + " in state file"); } loaded(); } virtual void syncData(SyncFunction f) { typename ArrayType::size_type S; f(array->data(), array->num_elements(), translateMPIType()); } virtual void fill(const element &v) { //#pragma omp simd #pragma omp parallel for for (size_t i = 0; i < array->num_elements(); i++) array->data()[i] = v; } }; template < typename T, std::size_t DIMENSIONS, typename Allocator = LibLSS::track_allocator, bool NeedReassembly = false> class ArrayStateElement : public GenericArrayStateElement< boost::multi_array, NeedReassembly> { typedef GenericArrayStateElement< boost::multi_array, NeedReassembly> super_type; public: typedef typename super_type::ArrayType ArrayType; typedef typename boost::multi_array_ref RefArrayType; typedef typename super_type::index_gen index_gen; enum { AlignState = DetectAlignment::Align }; typedef Eigen::Array E_Array; typedef Eigen::Map MapArray; template ArrayStateElement( const ExtentList &extents, const Allocator &allocator = Allocator(), const boost::general_storage_order &ordering = boost::c_storage_order()) : super_type() { this->array = std::make_shared(extents, ordering, 
allocator); Console::instance().print( std::string("Creating array which is ") + ((((int)AlignState == (int)Eigen::Aligned) ? "ALIGNED" : "UNALIGNED"))); } MapArray eigen() { return MapArray(this->array->data(), this->array->num_elements()); } virtual void fill(const typename super_type::element &v) { eigen().fill(v); } // This is unsafe. Use it with precaution void unsafeSetName(const std::string &n) { this->name = n; } }; template class RefArrayStateElement : public GenericArrayStateElement> { public: typedef boost::multi_array_ref ArrayType; typedef boost::multi_array_ref RefArrayType; template RefArrayStateElement( T *data, const ExtentList &extents, const boost::general_storage_order &ordering = boost::c_storage_order()) : StateElement() { this->array = std::make_shared(data, extents); } }; template struct _scalar_writer { template static inline void write(H5::DataSet &dataset, U &v, DT dt) { dataset.write(&v, dt.type()); } template static inline void read(H5::DataSet &dataset, U &v, DT dt) { dataset.read(&v, dt.type()); } }; template <> struct _scalar_writer { template static inline void write(H5::DataSet &dataset, std::string &v, DT dt) { dataset.write(v, dt.type()); } template static inline void read(H5::DataSet &dataset, std::string &v, DT dt) { dataset.read(v, dt.type()); } }; /* Generic scalar Markov State element. 
*/ template class ScalarStateElement : public StateElement { public: T value; T reset_value; bool resetOnSave; bool doNotRestore; ScalarStateElement() : StateElement(), value(), reset_value(), resetOnSave(false), doNotRestore(false) {} virtual ~ScalarStateElement() {} void setDoNotRestore(bool doNotRestore) { this->doNotRestore = doNotRestore; } void setResetOnSave(const T &_reset_value) { this->reset_value = _reset_value; resetOnSave = true; } virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) { CosmoTool::get_hdf5_data_type hdf_data_type; std::vector dimensions(1); dimensions[0] = 1; if (partialSave || (comm != 0 && comm->rank() == 0)) { checkName(); H5::DataSpace dataspace(1, dimensions.data()); H5::DataSet dataset = (*fg).createDataSet(name, hdf_data_type.type(), dataspace); _scalar_writer::write(dataset, value, hdf_data_type); if (resetOnSave) value = reset_value; } } virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = true) { CosmoTool::get_hdf5_data_type hdf_data_type; std::vector dimensions(1); H5::DataSet dataset; if (doNotRestore) { return; } dimensions[0] = 1; checkName(); try { dataset = fg.openDataSet(name); } catch (const H5::GroupIException &) { error_helper( "Could not find variable " + name + " in state file."); } H5::DataSpace dataspace = dataset.getSpace(); hsize_t n; if (dataspace.getSimpleExtentNdims() != 1) error_helper("Invalid stored dimension for " + getName()); dataspace.getSimpleExtentDims(&n); if (n != 1) error_helper("Invalid stored dimension for " + getName()); _scalar_writer::read(dataset, value, hdf_data_type); loaded(); } operator T() { return value; } virtual bool isScalar() const { return true; } virtual void syncData(SyncFunction f) { error_helper( "MPI synchronization not supported by this type"); } }; template class SyncableScalarStateElement : public ScalarStateElement { public: typedef typename ScalarStateElement::SyncFunction SyncFunction; virtual void 
syncData(SyncFunction f) { f(&this->value, 1, translateMPIType()); } }; template class SharedObjectStateElement : public StateElement { public: std::shared_ptr obj; SharedObjectStateElement() : StateElement() {} SharedObjectStateElement(std::shared_ptr &src) : StateElement(), obj(src) {} SharedObjectStateElement(std::shared_ptr &&src) : StateElement(), obj(src) {} virtual ~SharedObjectStateElement() {} virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) { if (fg) obj->save(*fg); } virtual void loadFrom(CosmoTool::H5_CommonFileGroup &fg, bool partialSave = true) { obj->restore(fg); loaded(); } operator T &() { return *obj; } T &get() { return *obj; } const T &get() const { return *obj; } virtual void syncData(SyncFunction f) {} }; template class ObjectStateElement : public StateElement { public: T *obj; ObjectStateElement() : StateElement() {} ObjectStateElement(T *o) : StateElement(), obj(o) {} virtual ~ObjectStateElement() { if (autofree) delete obj; } virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) { if (fg) obj->save(*fg); } virtual void loadFrom(H5_CommonFileGroup &fg, bool partialSave = true) { obj->restore(fg); loaded(); } operator T &() { return *obj; } T &get() { return *obj; } const T &get() const { return *obj; } virtual void syncData(SyncFunction f) {} }; template class TemporaryElement : public StateElement { protected: T obj; public: TemporaryElement(T const &a) : obj(a) {} TemporaryElement(T &&a) : obj(a) {} operator T &() { return obj; } T &get() { return obj; } const T &get() const { return obj; } virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) {} virtual void loadFrom(H5_CommonFileGroup &fg, bool partialSave = true) {} virtual void syncData(SyncFunction f) {} }; template class RandomStateElement : public StateElement { protected: std::shared_ptr rng; public: RandomStateElement(T *generator, bool handover_ = false) 
{ if (handover_) { rng = std::shared_ptr(generator, [](T *a) { delete a; }); } else { rng = std::shared_ptr(generator, [](T *a) {}); } } RandomStateElement(std::shared_ptr generator) : rng(generator) {} virtual ~RandomStateElement() {} const T &get() const { return *rng; } T &get() { return *rng; } virtual void saveTo( boost::optional fg, MPI_Communication *comm = 0, bool partialSave = true) { if (fg) rng->save(*fg); } virtual void loadFrom(H5_CommonFileGroup &fg, bool partialLoad = false) { rng->restore(fg, partialLoad); loaded(); } virtual void syncData(SyncFunction f) {} }; }; // namespace LibLSS #endif