Hardened SimuData. Support writeGadget

This commit is contained in:
Guilhem Lavaux 2014-05-30 19:00:25 +02:00
parent adf14da4b4
commit f4187185a7
2 changed files with 99 additions and 15 deletions

View File

@ -1,4 +1,5 @@
from libcpp cimport bool from libcpp cimport bool
from libcpp cimport string as cppstring
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
from cpython cimport PyObject, Py_INCREF from cpython cimport PyObject, Py_INCREF
@ -22,6 +23,8 @@ cdef extern from "loadSimu.hpp" namespace "CosmoTool":
float *Vel[3] float *Vel[3]
int *type int *type
bool noAuto
cdef const int NEED_GADGET_ID cdef const int NEED_GADGET_ID
cdef const int NEED_POSITION cdef const int NEED_POSITION
cdef const int NEED_VELOCITY cdef const int NEED_VELOCITY
@ -30,10 +33,35 @@ cdef extern from "loadSimu.hpp" namespace "CosmoTool":
cdef extern from "loadGadget.hpp" namespace "CosmoTool": cdef extern from "loadGadget.hpp" namespace "CosmoTool":
SimuData *loadGadgetMulti(const char *fname, int id, int flags) except + SimuData *loadGadgetMulti(const char *fname, int id, int flags) except +
void cxx_writeGadget "CosmoTool::writeGadget" (const char * s, SimuData *data) except +
cdef extern from "loadRamses.hpp" namespace "CosmoTool": cdef extern from "loadRamses.hpp" namespace "CosmoTool":
SimuData *loadRamsesSimu(const char *basename, int id, int cpuid, bool dp, int flags) except + SimuData *loadRamsesSimu(const char *basename, int id, int cpuid, bool dp, int flags) except +
class PySimulationBase(object):
def getPositions(self):
raise NotImplemented("getPositions is not implemented")
def getVelocities(self):
raise NotImplemented("getVelocities is not implemented")
def getIdentifiers(self):
raise NotImplemented("getIdentifiers is not implemented")
def getOmega_M(self):
raise NotImplemented("getOmega_M is not implemented")
def getOmega_Lambda(self):
raise NotImplemented("getOmega_Lambda is not implemented")
def getTime(self):
raise NotImplemented("getTime is not implemented")
def getHubble(self):
raise NotImplemented("getHubble is not implemented")
cdef class Simulation: cdef class Simulation:
cdef list positions cdef list positions
@ -79,6 +107,27 @@ cdef class Simulation:
del self.data del self.data
class _PySimulationAdaptor(PySimulationBase):
def __init__(self,sim):
self.simu = sim
def getPositions(self):
return self.simu.positions
def getVelocities(self):
return self.simu.velocities
def getIdentifiers(self):
return self.simu.identifiers
def getTime(self):
return self.simu.time
def getHubble(self):
return self.simul.Hubble
cdef class ArrayWrapper: cdef class ArrayWrapper:
cdef void* data_ptr cdef void* data_ptr
cdef int size cdef int size
@ -167,7 +216,39 @@ def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loa
if data == <SimuData*>0: if data == <SimuData*>0:
return None return None
return wrap_simudata(data, flags) return _PySimulationAdaptor(wrap_simudata(data, flags))
def writeGadget(str filename, object simulation):
cdef SimuData simdata
cdef np.ndarray[np.float_t, ndim=1] pos, vel
cdef np.int64_t NumPart
if not isinstance(simulation,PySimulationBase):
raise TypeError("Second argument must be of type SimulationBase")
NumPart = simulation.positions[0].size
simdata.noAuto = True
for j in xrange(3):
pos = simulation.getPositions()[j]
vel = simulation.getVelocities()[j]
if pos.size != NumPart or vel.size != NumPart:
raise ValueError("Invalid number of particles")
simdata.Pos[j] = <float *>pos.data
simdata.Vel[j] = <float *>vel.data
simdata.BoxSize = simulation.getBoxSize()
simdata.time = simulation.getTime()
simdata.Hubble = simulation.getHubble()
simdata.Omega_M = simulation.getOmega_M()
simdata.Omega_Lambda = simulation.getOmega_Lambda()
simdata.TotalNumPart = NumPart
simdata.NumPart = NumPart
cxx_writeGadget(filename, &simdata)
def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision = False, bool loadPosition = True, bool loadVelocity = False): def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision = False, bool loadPosition = True, bool loadVelocity = False):
""" loadRamses(basepath, snapshot_id, cpu_id, doublePrecision=False, loadPosition=True, loadVelocity=False) """ loadRamses(basepath, snapshot_id, cpu_id, doublePrecision=False, loadPosition=True, loadVelocity=False)
@ -187,4 +268,4 @@ def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision =
if data == <SimuData*>0: if data == <SimuData*>0:
return None return None
return wrap_simudata(data, flags) return _PySimulationAdaptor(wrap_simudata(data, flags))

View File

@ -64,6 +64,8 @@ namespace CosmoTool
typedef void (*FreeFunction)(void *); typedef void (*FreeFunction)(void *);
typedef std::map<std::string, std::pair<void *, FreeFunction> > AttributeMap; typedef std::map<std::string, std::pair<void *, FreeFunction> > AttributeMap;
bool noAuto;
float BoxSize; float BoxSize;
float time; float time;
float Hubble; float Hubble;
@ -81,9 +83,10 @@ namespace CosmoTool
AttributeMap attributes; AttributeMap attributes;
public: public:
SimuData() : Id(0),NumPart(0),type(0) { Pos[0]=Pos[1]=Pos[2]=0; Vel[0]=Vel[1]=Vel[2]=0; } SimuData() : Id(0),NumPart(0),type(0),noAuto(false) { Pos[0]=Pos[1]=Pos[2]=0; Vel[0]=Vel[1]=Vel[2]=0; }
~SimuData() ~SimuData()
{ {
if (!noAuto) {
for (int j = 0; j < 3; j++) for (int j = 0; j < 3; j++)
{ {
if (Pos[j]) if (Pos[j])
@ -95,7 +98,7 @@ namespace CosmoTool
delete[] type; delete[] type;
if (Id) if (Id)
delete[] Id; delete[] Id;
}
for (AttributeMap::iterator i = attributes.begin(); for (AttributeMap::iterator i = attributes.begin();
i != attributes.end(); i != attributes.end();
++i) ++i)