Commit cf9794e7 authored by Marcus Wirtz's avatar Marcus Wirtz
Browse files

[cosmic_rays] Fix inheritance bug when slicing a CosmicRaysBase object

parent 26573921
Pipeline #149407 passed with stages
in 4 minutes and 1 second
......@@ -109,12 +109,25 @@ class CosmicRaysBase(container.DataContainer):
self.type = "CosmicRays"
def __getitem__(self, key):
try:
return super(CosmicRaysBase, self).__getitem__(key)
except ValueError:
if len(self._similar_key(key)) > 0:
return self._get_values_similar_key(self._similar_key(key).pop(), key)
raise ValueError("Key '%s' does not exist, no info stored under similar keys was found" % key)
if isinstance(key, (int, np.integer, np.ndarray, slice)):
crs = CosmicRaysBase(self.shape_array[key])
for k in self.general_object_store.keys():
to_copy = self.get(k)
if isinstance(to_copy, (np.ndarray, list)):
if len(to_copy) == self.ncrs:
to_copy = to_copy[key]
crs.__setitem__(k, to_copy)
return crs
if key in self.general_object_store.keys():
return self.general_object_store[key]
if key in self.shape_array.dtype.names:
return self.shape_array[key]
if len(self._similar_key(key)) > 0:
return self._get_values_similar_key(self._similar_key(key).pop(), key)
raise ValueError("Key '%s' does not exist, no info stored under similar keys was found" % key)
def __setitem__(self, key, value):
if key in self.shape_array.dtype.names:
......
......@@ -335,7 +335,7 @@ class PlotSkyPatch:
def __init__(self, lon_roi, lat_roi, r_roi, ax=None, title=None, **kwargs):
"""
:param lon_roi: Longitude of center of ROI in radians (0..2*pi)
:param lat_roi: Latitude of center of ROI in radians (0..2*pi)
:param lat_roi: Latitude of center of ROI in radians (-pi/2 .. pi/2)
:param r_roi: Radius of ROI to be plotted (in radians)
:param ax: Matplotlib axes in case you want to plot on certain axes
:param title: Optional title of plot (plotted in upper left corner)
......
......@@ -301,6 +301,8 @@ class TestCosmicRays(unittest.TestCase):
self.assertTrue(hasattr(crs_sub, 'keys'))
self.assertTrue(len(crs_sub) < self.ncrs)
self.assertTrue(len(crs_sub['energy']) == len(crs_sub))
self.assertTrue(crs.type == 'CosmicRays')
self.assertTrue(crs_sub.type == 'CosmicRays')
def test_26_set_unfortunate_length_of_string(self):
_str = 'hallo'
......@@ -751,6 +753,7 @@ class TestCosmicRaysSets(unittest.TestCase):
def test_23_mask_nsets_ncrs(self):
nsets, ncrs = 5, 100
crs = CosmicRaysSets((nsets, ncrs))
self.assertTrue(crs.type == 'CosmicRaysSet')
energies = np.linspace(0, 100, ncrs)
crs['energy'] = energies
mask = np.zeros((nsets, ncrs), dtype=bool)
......@@ -758,6 +761,7 @@ class TestCosmicRaysSets(unittest.TestCase):
crs = crs[mask]
self.assertTrue(crs.shape == (nsets, 70))
self.assertTrue(crs.ncrs == 70)
self.assertTrue(crs.type == 'CosmicRaysSet')
def test_24_mask_arbitrary(self):
crs = CosmicRaysSets(self.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment