cosmotool/python/_cosmotool.pyx

301 lines
7.7 KiB
Cython
Raw Normal View History

2014-05-25 11:11:44 +02:00
from libcpp cimport bool
2014-05-30 19:00:25 +02:00
from libcpp cimport string as cppstring
2014-05-25 10:43:06 +02:00
import numpy as np
2014-05-25 11:11:44 +02:00
cimport numpy as np
2014-05-26 14:24:53 +02:00
from cpython cimport PyObject, Py_INCREF
# Initialise NumPy's C API; this must run before any np.PyArray_* function
# (such as PyArray_SimpleNewFromData used by the array wrappers) is called.
np.import_array()
2014-05-25 11:11:44 +02:00
cdef extern from "loadSimu.hpp" namespace "CosmoTool":
    # Bit flags telling a loader which particle fields to fill in; the Python
    # wrappers combine them with | and test them with &.
    cdef const int NEED_GADGET_ID
    cdef const int NEED_POSITION
    cdef const int NEED_VELOCITY
    cdef const int NEED_TYPE

    # C++ container for one snapshot: header values plus particle buffers.
    cdef cppclass SimuData:
        # Header quantities.
        np.float_t BoxSize
        np.float_t time
        np.float_t Hubble
        np.float_t Omega_M
        np.float_t Omega_Lambda

        # Particle counts; the wrappers use NumPart as the buffer length.
        np.int64_t TotalNumPart
        np.int64_t NumPart

        # Per-particle buffers (one pointer per axis for Pos and Vel).
        np.int64_t *Id
        float *Pos[3]
        float *Vel[3]
        int *type

        # NOTE(review): writeGadget sets this when the buffers point into
        # NumPy-owned memory; presumably it stops SimuData from freeing them
        # on destruction -- confirm against loadSimu.hpp.
        bool noAuto
cdef extern from "loadGadget.hpp" namespace "CosmoTool":
    # Writer, bound under a distinct Python-level name so it does not clash
    # with this module's own ``writeGadget`` wrapper function.
    void cxx_writeGadget "CosmoTool::writeGadget" (const char * s, SimuData *data) except +
    # Reader; `except +` turns C++ exceptions into Python exceptions.
    SimuData *loadGadgetMulti(const char *fname, int id, int flags) except +
2014-05-25 11:11:44 +02:00
# Declaration of CosmoTool's RAMSES snapshot reader.
cdef extern from "loadRamses.hpp" namespace "CosmoTool":
    # `except +` converts C++ exceptions into Python exceptions.
    SimuData *loadRamsesSimu(const char *basename, int id, int cpuid, bool dp, int flags) except +
2014-05-25 11:11:44 +02:00
2014-05-30 19:00:25 +02:00
class PySimulationBase(object):
    """Python-level interface describing a simulation snapshot.

    Concrete implementations override the accessors they can provide; any
    accessor left alone raises NotImplementedError.  ``writeGadget`` accepts
    any instance of this class.
    """

    # ---- particle data -------------------------------------------------

    def getPositions(self):
        """Return the particle positions (one array per axis)."""
        raise NotImplementedError("getPositions is not implemented")

    def getVelocities(self):
        """Return the particle velocities (one array per axis)."""
        raise NotImplementedError("getVelocities is not implemented")

    def getIdentifiers(self):
        """Return the particle identifiers."""
        raise NotImplementedError("getIdentifiers is not implemented")

    # ---- header / cosmology ---------------------------------------------

    def getBoxsize(self):
        """Return the box size."""
        raise NotImplementedError("getBoxsize is not implemented")

    def getTime(self):
        """Return the snapshot time."""
        raise NotImplementedError("getTime is not implemented")

    def getHubble(self):
        """Return the Hubble parameter."""
        raise NotImplementedError("getHubble is not implemented")

    def getOmega_M(self):
        """Return the matter density parameter."""
        raise NotImplementedError("getOmega_M is not implemented")

    def getOmega_Lambda(self):
        """Return the cosmological-constant density parameter."""
        raise NotImplementedError("getOmega_Lambda is not implemented")
2014-05-30 19:00:25 +02:00
2014-05-25 11:11:44 +02:00
cdef class _SimuDataOwner:
    """Sole owner of a heap-allocated SimuData instance.

    Both the Simulation wrapper and every NumPy array viewing one of the
    particle buffers hold a reference to this object, so the C++ data is
    deleted only when none of them is alive any more.  The owner holds no
    Python references itself, which keeps it out of reference cycles: NumPy
    arrays are not traversed by the cyclic garbage collector, so a cycle
    passing through an array (e.g. array -> Simulation -> list -> array)
    would never be freed.
    """
    cdef SimuData *data

    def __cinit__(_SimuDataOwner self):
        self.data = NULL

    def __dealloc__(_SimuDataOwner self):
        if self.data != NULL:
            del self.data
            self.data = NULL


cdef class Simulation:
    """Read-only view of a SimuData snapshot loaded by CosmoTool.

    The particle arrays share memory with the C++ buffers (no copy).  They
    keep the underlying data alive on their own, so they remain valid even
    after this object has been released.
    """
    cdef list positions         # per-axis float32 arrays, or None if not loaded
    cdef list velocities        # per-axis float32 arrays, or None if not loaded
    cdef object identifiers     # int64 array, or None if not loaded
    cdef SimuData *data         # borrowed pointer; lifetime managed by `owner`
    cdef _SimuDataOwner owner   # keeps `data` alive while this object exists

    property BoxSize:
        def __get__(Simulation self):
            return self.data.BoxSize

    property time:
        def __get__(Simulation self):
            return self.data.time

    property Hubble:
        def __get__(Simulation self):
            return self.data.Hubble

    property Omega_M:
        def __get__(Simulation self):
            return self.data.Omega_M

    property Omega_Lambda:
        def __get__(Simulation self):
            return self.data.Omega_Lambda

    property positions:
        def __get__(Simulation self):
            return self.positions

    property velocities:
        def __get__(Simulation self):
            return self.velocities

    property identifiers:
        def __get__(Simulation self):
            return self.identifiers

    property numParticles:
        def __get__(Simulation self):
            return self.data.NumPart

    def __cinit__(Simulation self):
        # The SimuData itself is freed by `owner` (see _SimuDataOwner), not
        # here, so arrays handed out to users never dangle.
        self.data = NULL


class _PySimulationAdaptor(PySimulationBase):
    """Expose a Simulation through the PySimulationBase interface."""

    def __init__(self, sim):
        self.simu = sim

    def getPositions(self):
        return self.simu.positions

    def getVelocities(self):
        return self.simu.velocities

    def getIdentifiers(self):
        return self.simu.identifiers

    def getTime(self):
        return self.simu.time

    def getHubble(self):
        return self.simu.Hubble

    def getOmega_M(self):
        return self.simu.Omega_M

    def getOmega_Lambda(self):
        return self.simu.Omega_Lambda

    def getBoxsize(self):
        # Required by writeGadget; without it the base class raises
        # NotImplementedError for every loaded simulation.
        return self.simu.BoxSize


cdef class ArrayWrapper:
    """Expose a raw C buffer to NumPy without copying it.

    The wrapper does not own the buffer.  It keeps `owner` (the object
    responsible for freeing the memory) alive for as long as any array
    created through it exists.
    """
    cdef void* data_ptr
    cdef np.npy_intp size
    cdef int type_array
    cdef object owner

    cdef set_data(self, np.npy_intp size, int type_array, void* data_ptr, object owner=None):
        """Set the data of the array.

        This cannot be done in the constructor as it must receive C-level
        arguments.

        Parameters:
        -----------
        size: npy_intp
            Length of the array (npy_intp so 64-bit particle counts are not
            truncated).
        type_array: int
            NumPy type number of the elements.
        data_ptr: void*
            Pointer to the data.
        owner: object
            Object that keeps `data_ptr` valid; referenced, never freed here.
        """
        self.data_ptr = data_ptr
        self.size = size
        self.type_array = type_array
        self.owner = owner

    def __array__(self):
        """Return a 1-D array viewing the buffer; the array keeps this
        wrapper (and therefore `owner`) alive."""
        cdef np.npy_intp shape[1]
        cdef np.ndarray ndarray
        shape[0] = self.size
        ndarray = np.PyArray_SimpleNewFromData(1, shape, self.type_array, self.data_ptr)
        # set_array_base takes its own reference to `self`.
        np.set_array_base(ndarray, self)
        return ndarray

    def __dealloc__(self):
        """Nothing to free: the buffer belongs to `owner`."""
        pass


cdef object wrap_array(void *p, np.uint64_t s, int typ, object owner=None):
    """Return a 1-D NumPy array of `s` elements of type number `typ` that
    views `p` without copying and keeps `owner` alive."""
    cdef ArrayWrapper wrapper
    wrapper = ArrayWrapper()
    wrapper.set_data(<np.npy_intp>s, typ, p, owner)
    return wrapper.__array__()


cdef object wrap_float_array(float *p, np.uint64_t s, object owner=None):
    return wrap_array(<void *>p, s, np.NPY_FLOAT32, owner)


cdef object wrap_int64_array(np.int64_t* p, np.uint64_t s, object owner=None):
    return wrap_array(<void *>p, s, np.NPY_INT64, owner)


cdef object wrap_simudata(SimuData *data, int flags):
    """Wrap a heap-allocated SimuData in a Simulation, taking ownership.

    Only the fields selected in `flags` are exposed; the others are None.
    """
    cdef Simulation simu
    cdef _SimuDataOwner owner

    # Hand ownership over first so `data` is freed even if wrapping fails.
    owner = _SimuDataOwner()
    owner.data = data

    simu = Simulation()
    simu.data = data
    simu.owner = owner

    if flags & NEED_POSITION:
        simu.positions = [wrap_float_array(data.Pos[i], data.NumPart, owner) for i in xrange(3)]
    else:
        simu.positions = None
    if flags & NEED_VELOCITY:
        simu.velocities = [wrap_float_array(data.Vel[i], data.NumPart, owner) for i in xrange(3)]
    else:
        simu.velocities = None
    if flags & NEED_GADGET_ID:
        simu.identifiers = wrap_int64_array(data.Id, data.NumPart, owner)
    else:
        simu.identifiers = None
    return simu
2014-05-30 13:21:12 +02:00
def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loadVelocity = False, bool loadId = False):
    """loadGadget(filename, snapshot_id, loadPosition=True, loadVelocity=False, loadId=False)

    Load a Gadget snapshot through CosmoTool's loadGadgetMulti.

    Only the requested particle fields are loaded; the accessors of the
    returned object give None for fields that were not requested.

    Returns a PySimulationBase-compatible object, or None if the loader
    returned no data.  C++ loader errors are raised as Python exceptions.
    """
    cdef int flags = 0
    cdef SimuData *data
    cdef object c_filename

    if loadPosition:
        flags |= NEED_POSITION
    if loadVelocity:
        flags |= NEED_VELOCITY
    if loadId:
        flags |= NEED_GADGET_ID

    # The C++ loader takes a byte string: text paths (Python 3 str) must be
    # encoded explicitly, byte paths (Python 2 str) are passed through.
    c_filename = filename if isinstance(filename, bytes) else filename.encode('utf-8')
    data = loadGadgetMulti(c_filename, snapshot_id, flags)
    if data == NULL:
        return None

    return _PySimulationAdaptor(wrap_simudata(data, flags))
def writeGadget(str filename, object simulation):
    """writeGadget(filename, simulation)

    Write `simulation` to `filename` in Gadget format via CosmoTool.

    `simulation` must be a PySimulationBase providing positions and
    velocities (three per-axis arrays each), identifiers, and the header
    values (box size, time, Hubble, Omega_M, Omega_Lambda).  All particle
    arrays must have the same length.  Inputs are converted to contiguous
    float32 / int64 buffers when they are not already.

    Raises TypeError if `simulation` is not a PySimulationBase, and
    ValueError if particle data is missing or the array lengths disagree.
    """
    cdef SimuData simdata
    cdef np.ndarray[np.float32_t, ndim=1] pos, vel
    cdef np.ndarray[np.int64_t, ndim=1] ids
    cdef np.int64_t NumPart
    cdef int j
    cdef list pos_arrays, vel_arrays
    cdef object c_filename

    if not isinstance(simulation, PySimulationBase):
        raise TypeError("Second argument must be of type SimulationBase")

    # Go through the PySimulationBase accessors: adaptor objects expose no
    # `positions` attribute of their own.
    positions = simulation.getPositions()
    velocities = simulation.getVelocities()
    identifiers = simulation.getIdentifiers()
    if positions is None:
        raise ValueError("Positions are required to write a Gadget file")
    if velocities is None:
        raise ValueError("Velocities are required to write a Gadget file")
    if identifiers is None:
        raise ValueError("Identifiers are required to write a Gadget file")

    # simdata only stores raw pointers, so every buffer must stay referenced
    # (by these lists / `ids`) until the C++ writer returns.  Contiguity is
    # required because the writer reads `.data` as a dense C array.
    pos_arrays = [np.ascontiguousarray(positions[axis], dtype=np.float32) for axis in xrange(3)]
    vel_arrays = [np.ascontiguousarray(velocities[axis], dtype=np.float32) for axis in xrange(3)]
    ids = np.ascontiguousarray(identifiers, dtype=np.int64)

    NumPart = pos_arrays[0].size
    # The buffers belong to NumPy, not to simdata.
    simdata.noAuto = True
    for j in xrange(3):
        pos = pos_arrays[j]
        vel = vel_arrays[j]
        if pos.size != NumPart or vel.size != NumPart:
            raise ValueError("Invalid number of particles")
        simdata.Pos[j] = <float *>pos.data
        simdata.Vel[j] = <float *>vel.data
    if ids.size != NumPart:
        raise ValueError("Invalid number of particles")
    simdata.Id = <np.int64_t *>ids.data

    simdata.BoxSize = simulation.getBoxsize()
    simdata.time = simulation.getTime()
    simdata.Hubble = simulation.getHubble()
    simdata.Omega_M = simulation.getOmega_M()
    simdata.Omega_Lambda = simulation.getOmega_Lambda()
    simdata.TotalNumPart = NumPart
    simdata.NumPart = NumPart

    # The C++ writer takes a byte string; encode text paths (Python 3 str).
    c_filename = filename if isinstance(filename, bytes) else filename.encode('utf-8')
    cxx_writeGadget(c_filename, &simdata)
def loadRamses(str basepath, int snapshot_id, int cpu_id, bool doublePrecision = False, bool loadPosition = True, bool loadVelocity = False):
    """ loadRamses(basepath, snapshot_id, cpu_id, doublePrecision=False, loadPosition=True, loadVelocity=False)
    Loads the indicated snapshot based on the cpu id, snapshot id and basepath. It is important to specify the correct precision in doublePrecision.

    Only the requested particle fields are loaded; the accessors of the
    returned object give None for fields that were not requested
    (identifiers are never requested by this loader).

    Returns a PySimulationBase-compatible object, or None if the loader
    returned no data.  C++ loader errors are raised as Python exceptions.
    """
    cdef int flags = 0
    cdef SimuData *data
    cdef object c_basepath

    if loadPosition:
        flags |= NEED_POSITION
    if loadVelocity:
        flags |= NEED_VELOCITY

    # The C++ loader takes a byte string: text paths (Python 3 str) must be
    # encoded explicitly, byte paths (Python 2 str) are passed through.
    c_basepath = basepath if isinstance(basepath, bytes) else basepath.encode('utf-8')
    data = loadRamsesSimu(c_basepath, snapshot_id, cpu_id, doublePrecision, flags)
    if data == NULL:
        return None

    return _PySimulationAdaptor(wrap_simudata(data, flags))