Commit 7b774048 authored by Marcus Wirtz's avatar Marcus Wirtz
Browse files

[cosmic_rays] Fix squeezing bug for nsets=1

parent c4e0c21a
Pipeline #225216 passed with stages
in 6 minutes and 19 seconds
......@@ -402,8 +402,9 @@ class CosmicRaysSets(CosmicRaysBase):
except ValueError as e:
if len(self._similar_key(key)) > 0:
value = self._get_values_similar_key(self._similar_key(key).pop(), key)
if value.size in (np.prod(self.shape), 3 * np.prod(self.shape)):
return np.squeeze(np.reshape(value, (-1, self.nsets, self.ncrs)))
if value.size in (np.prod(self.shape), 3*np.prod(self.shape)):
shape = self.shape if value.size == np.prod(self.shape) else (-1,)+self.shape
return np.reshape(value, shape)
raise Exception("Weird error occured, please report this incident with a minimal example!")
raise ValueError("The key %s does not exist and the error message was %s" % (key, str(e)))
......
......@@ -988,7 +988,14 @@ class TestCosmicRaysSets(unittest.TestCase):
self.assertTrue(np.shape(pvals) == (shape[0], 60))
self.assertTrue(((pvals >= 0) & (pvals <= 1)).all())
def test_32_shuffle(self):
def test_32_single_set(self):
shape = (1, 100)
crs = CosmicRaysSets(shape)
crs['pixel'] = np.ones(shape).astype(int)
self.assertTrue(np.shape(crs['vecs']) == (3,) + shape)
self.assertTrue(np.shape(crs['lon']) == shape)
def test_33_shuffle(self):
crs = CosmicRaysSets(self.shape)
crs["log10e"] = np.random.random(self.shape)
crs["vecs"] = coord.rand_vec(self.shape)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment