grab particle ids too

This commit is contained in:
Guilhem Lavaux 2014-05-30 13:21:12 +02:00
parent bf4977721f
commit 2afccc3411

View file

@ -38,6 +38,7 @@ cdef class Simulation:
cdef list positions
cdef list velocities
cdef object identifiers
cdef SimuData *data
@ -45,6 +46,10 @@ cdef class Simulation:
def __get__(Simulation self):
return self.data.BoxSize
property time:
def __get__(Simulation self):
return self.data.time
property Hubble:
def __get__(Simulation self):
return self.data.Hubble
@ -77,8 +82,9 @@ cdef class Simulation:
cdef class ArrayWrapper:
cdef void* data_ptr
cdef int size
cdef int type_array
cdef set_data(self, int size, void* data_ptr):
cdef set_data(self, int size, int type_array, void* data_ptr):
""" Set the data of the array
This cannot be done in the constructor as it must recieve C-level
@ -94,6 +100,7 @@ Pointer to the data
"""
self.data_ptr = data_ptr
self.size = size
self.type_array = type_array
def __array__(self):
""" Here we use the __array__ method, that is called when numpy
@ -102,7 +109,7 @@ tries to get an array from the object."""
shape[0] = <np.npy_intp> self.size
# Create a 1D array, of length 'size'
ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.data_ptr)
ndarray = np.PyArray_SimpleNewFromData(1, shape, self.type_array, self.data_ptr)
return ndarray
def __dealloc__(self):
@ -110,18 +117,25 @@ tries to get an array from the object."""
references to the object are gone. """
pass
cdef object wrap_float_array(float *p, np.uint64_t s):
cdef object wrap_array(void *p, np.uint64_t s, int typ):
cdef np.ndarray ndarray
cdef ArrayWrapper wrapper
wrapper = ArrayWrapper()
wrapper.set_data(s, <void *>p)
wrapper.set_data(s, typ, p)
ndarray = np.array(wrapper, copy=False)
ndarray.base = <PyObject*> wrapper
Py_INCREF(wrapper)
return ndarray
cdef object wrap_float_array(float *p, np.uint64_t s):
return wrap_array(<void *>p, s, np.NPY_FLOAT)
cdef object wrap_int64_array(np.int64_t* p, np.uint64_t s):
return wrap_array(<void *>p, s, np.NPY_INT64)
cdef object wrap_simudata(SimuData *data, int flags):
cdef Simulation simu
@ -131,9 +145,11 @@ cdef object wrap_simudata(SimuData *data, int flags):
simu.positions = [wrap_float_array(data.Pos[i], data.NumPart) for i in xrange(3)]
if flags & NEED_VELOCITY:
simu.velocities = [wrap_float_array(data.Vel[i], data.NumPart) for i in xrange(3)]
if flags & NEED_GADGET_ID:
simu.identifiers = wrap_int64_array(data.Id, data.NumPart)
return simu
def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loadVelocity = False):
def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loadVelocity = False, bool loadId = False):
cdef int flags
cdef SimuData *data
@ -144,6 +160,8 @@ def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loa
flags |= NEED_POSITION
if loadVelocity:
flags |= NEED_VELOCITY
if loadId:
flags |= NEED_GADGET_ID
data = loadGadgetMulti(filename, snapshot_id, flags)
if data == <SimuData*>0: