Generic fix and path update

This commit is contained in:
Guilhem Lavaux 2017-07-11 15:36:19 +02:00
parent 019480c0e0
commit 07cfe4137f
6 changed files with 27 additions and 8 deletions

View file

@ -4,14 +4,23 @@ import numpy as np
import numexpr as ne
class CubeFT(object):
def __init__(self, L, N, max_cpu=-1):
def __init__(self, L, N, max_cpu=-1, width=32):
if width==32:
fourier_type='complex64'
real_type='float32'
elif width==64:
fourier_type='complex128'
real_type='float64'
else:
raise ValueError("Invalid bitwidth (must be 32 or 64)")
self.N = N
self.align = pyfftw.simd_alignment
self.L = float(L)
self.max_cpu = multiprocessing.cpu_count() if max_cpu < 0 else max_cpu
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype='complex64')
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype='float32')
self._dhat = pyfftw.n_byte_align_empty((self.N,self.N,self.N/2+1), self.align, dtype=fourier_type)
self._density = pyfftw.n_byte_align_empty((self.N,self.N,self.N), self.align, dtype=real_type)
self._irfft = pyfftw.FFTW(self._dhat, self._density, axes=(0,1,2), direction='FFTW_BACKWARD', threads=self.max_cpu)#, normalize_idft=False)
self._rfft = pyfftw.FFTW(self._density, self._dhat, axes=(0,1,2), threads=self.max_cpu) #, normalize_idft=False)

View file

@ -14,6 +14,7 @@ class SimulationBare(PySimulationBase):
self.positions = [q.copy() for q in s.getPositions()] if s.getPositions() is not None else None
self.velocities = [q.copy() for q in s.getVelocities()] if s.getVelocities() is not None else None
self.identifiers = s.getIdentifiers().copy() if s.getIdentifiers() is not None else None
self.types = s.getTypes().copy() if s.getTypes() is not None else None
self.boxsize = s.getBoxsize()
self.time = s.getTime()
self.Hubble = s.getHubble()
@ -53,11 +54,15 @@ class SimulationBare(PySimulationBase):
self.positions = _safe_merge(self.positions, other.getPositions())
self.velocities = _safe_merge(self.velocities, other.getVelocities())
self.identifiers = _safe_merge0(self.identifiers, other.getIdentifiers())
self.types = _safe_merge0(self.types, other.getTypes())
try:
self.masses = _safe_merge0(self.masses, other.getMasses())
except Exception as e:
warnings.warn("Unexpected exception: " + repr(e));
self.masses = None
def getTypes(self):
return self.types
def getPositions(self):
return self.positions