diff --git a/python/_cosmotool.pyx b/python/_cosmotool.pyx index fdc934e..091b9a4 100644 --- a/python/_cosmotool.pyx +++ b/python/_cosmotool.pyx @@ -38,6 +38,7 @@ cdef class Simulation: cdef list positions cdef list velocities + cdef object identifiers cdef SimuData *data @@ -45,6 +46,10 @@ cdef class Simulation: def __get__(Simulation self): return self.data.BoxSize + property time: + def __get__(Simulation self): + return self.data.time + property Hubble: def __get__(Simulation self): return self.data.Hubble @@ -77,8 +82,9 @@ cdef class Simulation: cdef class ArrayWrapper: cdef void* data_ptr cdef int size + cdef int type_array - cdef set_data(self, int size, void* data_ptr): + cdef set_data(self, int size, int type_array, void* data_ptr): """ Set the data of the array This cannot be done in the constructor as it must recieve C-level @@ -94,6 +100,7 @@ Pointer to the data """ self.data_ptr = data_ptr self.size = size + self.type_array = type_array def __array__(self): """ Here we use the __array__ method, that is called when numpy @@ -102,7 +109,7 @@ tries to get an array from the object.""" shape[0] = self.size # Create a 1D array, of length 'size' - ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.data_ptr) + ndarray = np.PyArray_SimpleNewFromData(1, shape, self.type_array, self.data_ptr) return ndarray def __dealloc__(self): @@ -110,18 +117,25 @@ tries to get an array from the object.""" references to the object are gone. """ pass -cdef object wrap_float_array(float *p, np.uint64_t s): +cdef object wrap_array(void *p, np.uint64_t s, int typ): cdef np.ndarray ndarray cdef ArrayWrapper wrapper wrapper = ArrayWrapper() - wrapper.set_data(s, p) + wrapper.set_data(s, typ, p) ndarray = np.array(wrapper, copy=False) ndarray.base = wrapper Py_INCREF(wrapper) return ndarray + +cdef object wrap_float_array(float *p, np.uint64_t s): + return wrap_array(p, s, np.NPY_FLOAT) + +cdef object wrap_int64_array(np.int64_t* p, np.uint64_t s): + return wrap_array(p, s, np.NPY_INT64) + cdef object wrap_simudata(SimuData *data, int flags): cdef Simulation simu @@ -131,9 +145,11 @@ cdef object wrap_simudata(SimuData *data, int flags): simu.positions = [wrap_float_array(data.Pos[i], data.NumPart) for i in xrange(3)] if flags & NEED_VELOCITY: simu.velocities = [wrap_float_array(data.Vel[i], data.NumPart) for i in xrange(3)] + if flags & NEED_GADGET_ID: + simu.identifiers = wrap_int64_array(data.Id, data.NumPart) return simu -def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loadVelocity = False): +def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loadVelocity = False, bool loadId = False): cdef int flags cdef SimuData *data @@ -144,6 +160,8 @@ def loadGadget(str filename, int snapshot_id, bool loadPosition = True, bool loa flags |= NEED_POSITION if loadVelocity: flags |= NEED_VELOCITY + if loadId: + flags |= NEED_GADGET_ID data = loadGadgetMulti(filename, snapshot_id, flags) if data == 0: