Hardened SimuData. Support writeGadget

This commit is contained in:
Guilhem Lavaux 2014-05-30 19:00:25 +02:00
parent adf14da4b4
commit f4187185a7
2 changed files with 99 additions and 15 deletions

View file

@ -1,4 +1,5 @@
from libcpp cimport bool
from libcpp cimport string as cppstring
import numpy as np
cimport numpy as np
from cpython cimport PyObject, Py_INCREF
@ -22,6 +23,8 @@ cdef extern from "loadSimu.hpp" namespace "CosmoTool":
float *Vel[3]
int *type
bool noAuto
cdef const int NEED_GADGET_ID
cdef const int NEED_POSITION
cdef const int NEED_VELOCITY
@ -30,10 +33,35 @@ cdef extern from "loadSimu.hpp" namespace "CosmoTool":
cdef extern from "loadGadget.hpp" namespace "CosmoTool":
SimuData *loadGadgetMulti(const char *fname, int id, int flags) except +
void cxx_writeGadget "CosmoTool::writeGadget" (const char * s, SimuData *data) except +
cdef extern from "loadRamses.hpp" namespace "CosmoTool":
SimuData *loadRamsesSimu(const char *basename, int id, int cpuid, bool dp, int flags) except +
class PySimulationBase(object):
def getPositions(self):
raise NotImplemented("getPositions is not implemented")
def getVelocities(self):
raise NotImplemented("getVelocities is not implemented")
def getIdentifiers(self):
raise NotImplemented("getIdentifiers is not implemented")
def getOmega_M(self):
raise NotImplemented("getOmega_M is not implemented")
def getOmega_Lambda(self):
raise NotImplemented("getOmega_Lambda is not implemented")
def getTime(self):
raise NotImplemented("getTime is not implemented")
def getHubble(self):
raise NotImplemented("getHubble is not implemented")
cdef class Simulation:
cdef list positions
@ -79,6 +107,27 @@ cdef class Simulation:
del self.data
class _PySimulationAdaptor(PySimulationBase):
def __init__(self,sim):
self.simu = sim
def getPositions(self):
return self.simu.positions
def getVelocities(self):
return self.simu.velocities
def getIdentifiers(self):
return self.simu.identifiers
def getTime(self):
return self.simu.time
def getHubble(self):
return self.simul.Hubble
cdef class ArrayWrapper:
cdef void* data_ptr
cdef int size
@ -167,7 +216,39 @@ def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loa
if data == <SimuData*>0:
return None
return wrap_simudata(data, flags)
return _PySimulationAdaptor(wrap_simudata(data, flags))
def writeGadget(str filename, object simulation):
cdef SimuData simdata
cdef np.ndarray[np.float_t, ndim=1] pos, vel
cdef np.int64_t NumPart
if not isinstance(simulation,PySimulationBase):
raise TypeError("Second argument must be of type SimulationBase")
NumPart = simulation.positions[0].size
simdata.noAuto = True
for j in xrange(3):
pos = simulation.getPositions()[j]
vel = simulation.getVelocities()[j]
if pos.size != NumPart or vel.size != NumPart:
raise ValueError("Invalid number of particles")
simdata.Pos[j] = <float *>pos.data
simdata.Vel[j] = <float *>vel.data
simdata.BoxSize = simulation.getBoxSize()
simdata.time = simulation.getTime()
simdata.Hubble = simulation.getHubble()
simdata.Omega_M = simulation.getOmega_M()
simdata.Omega_Lambda = simulation.getOmega_Lambda()
simdata.TotalNumPart = NumPart
simdata.NumPart = NumPart
cxx_writeGadget(filename, &simdata)
def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision = False, bool loadPosition = True, bool loadVelocity = False):
""" loadRamses(basepath, snapshot_id, cpu_id, doublePrecision=False, loadPosition=True, loadVelocity=False)
@ -187,4 +268,4 @@ def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision =
if data == <SimuData*>0:
return None
return wrap_simudata(data, flags)
return _PySimulationAdaptor(wrap_simudata(data, flags))

View file

@ -64,6 +64,8 @@ namespace CosmoTool
typedef void (*FreeFunction)(void *);
typedef std::map<std::string, std::pair<void *, FreeFunction> > AttributeMap;
bool noAuto;
float BoxSize;
float time;
float Hubble;
@ -81,21 +83,22 @@ namespace CosmoTool
AttributeMap attributes;
public:
SimuData() : Id(0),NumPart(0),type(0) { Pos[0]=Pos[1]=Pos[2]=0; Vel[0]=Vel[1]=Vel[2]=0; }
SimuData() : Id(0),NumPart(0),type(0),noAuto(false) { Pos[0]=Pos[1]=Pos[2]=0; Vel[0]=Vel[1]=Vel[2]=0; }
~SimuData()
{
for (int j = 0; j < 3; j++)
{
if (Pos[j])
delete[] Pos[j];
if (Vel[j])
delete[] Vel[j];
}
if (type)
delete[] type;
if (Id)
delete[] Id;
if (!noAuto) {
for (int j = 0; j < 3; j++)
{
if (Pos[j])
delete[] Pos[j];
if (Vel[j])
delete[] Vel[j];
}
if (type)
delete[] type;
if (Id)
delete[] Id;
}
for (AttributeMap::iterator i = attributes.begin();
i != attributes.end();
++i)