Commit 88357cb3 authored by Lukas Weber's avatar Lukas Weber

make mcextract more userfriendly

parent 2f7cdf33
......@@ -20,7 +20,7 @@ class MCArchive:
doc = json.load(f)
param_names = set(sum([list(task['parameters'].keys()) for task in doc], []))
observable_names = set(sum([list(task['results'].keys()) for task in doc], []))
observable_names = set(sum([list(task['results'].keys()) if task['results'] != None else [] for task in doc], []))
self.num_tasks = len(doc)
self.parameters = dict(zip(param_names, [[None for _ in range(self.num_tasks)] for _ in param_names]))
......@@ -30,7 +30,8 @@ class MCArchive:
for param, value in task['parameters'].items():
self.parameters[param][i] = value
for obs, value in task['results'].items():
results = task['results'] if task['results'] else {}
for obs, value in results.items():
o = self.observables[obs]
o.rebinning_bin_length[i] = int(value.get('rebin_len',0))
o.rebinning_bin_count[i] = int(value.get('rebin_count',0))
......@@ -47,9 +48,18 @@ class MCArchive:
def get_parameter(self, name, unique=False, filter={}):
selection = list(itertools.compress(self.parameters[name], self.filter_mask(filter)))
if len(selection) == 0:
raise KeyError('Parameter {} not found with filter {}'.format(name,filter))
if unique:
return list(sorted(set(selection)))
selection = list(sorted(set(selection)))
dtypes = set(type(p) for p in selection)
if len(dtypes) == 1:
dtype = list(dtypes)[0]
if dtype == float or dtype == int:
selection = np.array(selection)
return selection
def get_observable(self, name, filter={}):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment