Skip to content
Snippets Groups Projects
Commit db3017c2 authored by Tobias Hangleiter's avatar Tobias Hangleiter
Browse files

Merge branch 'bugfixes_and_improvements' into 'main'

Bugfixes and improvements

See merge request !68
parents be7e6a5f 5d1093be
No related branches found
No related tags found
No related merge requests found
Pipeline #1676947 failed
......@@ -31,7 +31,7 @@ classifiers = [
dependencies = [
"packaging",
"lazy-loader",
"qutech-util >= 2025.03.1",
"qutech-util @ git+https://git.rwth-aachen.de/qutech/qutil.git@master",
"numpy",
"scipy",
"matplotlib >= 3.7",
......
"""This module defines the PlotManager helper class."""
import contextlib
import os
import sys
import warnings
import weakref
from itertools import compress
from typing import (Dict, Any, Optional, Mapping, Tuple, ContextManager, Iterable, Union, List,
Literal)
Literal, TypeVar, Callable)
import matplotlib.pyplot as plt
import numpy as np
from cycler import Cycler
from matplotlib import gridspec, scale
from qutil.functools import wraps
from qutil.itertools import compress
from qutil.misc import filter_warnings
from qutil.plotting import assert_interactive_figure
from scipy import integrate, signal
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
_T = TypeVar('_T')
_P = ParamSpec('_P')
_keyT = Union[int, str, Tuple[int, str]]
_styleT = Union[str, os.PathLike, dict]
_styleT = Union[None, _styleT, List[_styleT]]
def with_plot_context(meth: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(meth)
def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
self, *args = args
with self.plot_context:
return meth(self, *args, **kwargs)
return wrapped
class PlotManager:
__instances = weakref.WeakSet()
......@@ -32,11 +52,11 @@ class PlotManager:
plot_negative_frequencies: bool = True, plot_absolute_frequencies: bool = True,
plot_amplitude: bool = True, plot_density: bool = True,
plot_cumulative_normalized: bool = False, plot_style: _styleT = 'fast',
plot_dB_scale: bool = False, threaded_acquisition: bool = True, prop_cycle=None,
raw_unit: str = 'V', processed_unit: Optional[str] = None,
uses_windowed_estimator: bool = True, complex_data: Optional[bool] = False,
figure_kw: Optional[Mapping] = None, subplot_kw: Optional[Mapping] = None,
gridspec_kw: Optional[Mapping] = None, legend_kw: Optional[Mapping] = None):
plot_dB_scale: bool = False, prop_cycle=None, raw_unit: str = 'V',
processed_unit: Optional[str] = None, uses_windowed_estimator: bool = True,
complex_data: Optional[bool] = False, figure_kw: Optional[Mapping] = None,
subplot_kw: Optional[Mapping] = None, gridspec_kw: Optional[Mapping] = None,
legend_kw: Optional[Mapping] = None):
"""A helper class that manages plotting spectrometer data."""
self._data = data
......@@ -51,13 +71,12 @@ class PlotManager:
self._plot_cumulative_normalized = plot_cumulative_normalized
self._plot_style = plot_style
self._plot_dB_scale = plot_dB_scale
self._threaded_acquisition = threaded_acquisition
self._processed_unit = processed_unit if processed_unit is not None else raw_unit
self._prop_cycle = prop_cycle
# For dB scale plots, default to the first spectrum acquired.
self._reference_spectrum: Optional[_keyT] = None
self.prop_cycle = prop_cycle or plt.rcParams['axes.prop_cycle']
self.raw_unit = raw_unit
self.uses_windowed_estimator = uses_windowed_estimator
self._complex_data = complex_data
......@@ -90,10 +109,13 @@ class PlotManager:
return self._fig is not None and plt.fignum_exists(self.figure_kw['num'])
@property
@with_plot_context
def fig(self):
"""The figure hosting the plots."""
if self.is_fig_open():
return self._fig
elif plt.fignum_exists(num := self.figure_kw['num']):
plt.figure(num).clear()
try:
self._fig = plt.figure(**self.figure_kw)
......@@ -118,8 +140,8 @@ class PlotManager:
self.destroy_axes()
self.update_line_attrs(self.plots_to_draw, self.lines_to_draw, self.shown, stale=True)
# If the window is closed, remove the figure from the cache so that it can be recreated and
# stop the timer to delete any remaining callbacks
# If the window is closed, remove the figure from the cache so that it can be
# recreated and stop the timer to delete any remaining callbacks
self._fig.canvas.mpl_connect('close_event', on_close)
self.setup_figure()
......@@ -162,6 +184,11 @@ class PlotManager:
def plots_to_draw(self) -> Tuple[str, ...]:
return tuple(compress(self.PLOT_TYPES, [True, self.plot_cumulative, self.plot_timetrace]))
@property
@with_plot_context
def prop_cycle(self) -> Cycler:
return self._prop_cycle or plt.rcParams['axes.prop_cycle']
@property
def plot_context(self) -> ContextManager:
if self.plot_style is not None:
......@@ -186,7 +213,20 @@ class PlotManager:
@property
def plot_cumulative(self) -> bool:
"""If the cumulative (integrated) PSD or spectrum is plotted on a subplot."""
r"""If the cumulative (integrated) PSD or spectrum is plotted on
a subplot.
The cumulative PSD is given by
.. math::
\mathrm{RMS}_S(f)^2 = \int_{f_\mathrm{min}}^f\mathrm{d}f^\prime\,S(f^\prime)
with :math:`\mathrm{RMS}_S(f)` the root-mean-square of
:math:`S(f^\prime)` up to frequency :math:`f^\prime`.
If :attr:`plot_amplitude` is true, plot the square root of this,
i.e., :math:`\mathrm{RMS}_S(f)`.
"""
return self._plot_cumulative
@plot_cumulative.setter
......@@ -331,17 +371,6 @@ class PlotManager:
if self.is_fig_open():
self.setup_figure()
@property
def threaded_acquisition(self) -> bool:
"""Acquire data in a separate thread."""
return self._threaded_acquisition
@threaded_acquisition.setter
def threaded_acquisition(self, val: bool):
val = bool(val)
if val != self._threaded_acquisition:
self._threaded_acquisition = val
@property
def reference_spectrum(self) -> Optional[Tuple[int, str]]:
"""Spectrum taken as a reference for the dB scale.
......@@ -382,19 +411,24 @@ class PlotManager:
if 'timetrace_processed' in d)
return complex_raw_data or complex_processed_data
def main_plot(self, key, line_type):
x, y = self.get_freq_data(key, line_type, self.plot_dB_scale)
d = self.lines[key]['main'][line_type]
@with_plot_context
def plot_or_set_props(self, x, y, key, plot_type, line_type):
d = self.lines[key][plot_type][line_type]
if line := d['line']:
line.set_data(x, y)
line.set_color(self.line_props(key[0], d)['color'])
line.set_alpha(self.line_props(key[0], d)['alpha'])
line.set_zorder(self.line_props(key[0], d)['zorder'])
else:
line, = self.axes['main'][line_type].plot(x, y, **self.line_props(key[0], d))
self.update_line_attrs(['main'], [line_type], [key], stale=False, line=line)
line, = self.axes[plot_type][line_type].plot(x, y, **self.line_props(key[0], d))
self.update_line_attrs([plot_type], [line_type], [key], stale=False, line=line)
@with_plot_context
def main_plot(self, key, line_type):
x, y = self.get_freq_data(key, line_type, self.plot_dB_scale)
self.plot_or_set_props(x, y, key, 'main', line_type)
@with_plot_context
def cumulative_plot(self, key, line_type):
# y is the power irrespective of whether self.plot_amplitude is True or not.
# This means that if the latter is True, this plot shows the cumulative RMS,
......@@ -409,44 +443,28 @@ class PlotManager:
if self.plot_amplitude:
y = np.sqrt(y)
if self.plot_cumulative_normalized:
y = (y - y.min()) / y.ptp()
d = self.lines[key]['cumulative'][line_type]
if line := d['line']:
line.set_data(x, y)
line.set_color(self.line_props(key[0], d)['color'])
line.set_alpha(self.line_props(key[0], d)['alpha'])
line.set_zorder(self.line_props(key[0], d)['zorder'])
else:
line, = self.axes['cumulative'][line_type].plot(x, y, **self.line_props(key[0], d))
self.update_line_attrs(['cumulative'], [line_type], [key], stale=False, line=line)
y = (y - y.min()) / np.ptp(y)
self.plot_or_set_props(x, y, key, 'cumulative', line_type)
@with_plot_context
def time_plot(self, key, line_type):
y = self._data[key][f'timetrace_{line_type}'][-1]
if np.iscomplexobj(y):
y = np.abs(y)
x = np.arange(y.size) / self._data[key]['settings']['fs']
self.plot_or_set_props(x, y, key, 'time', line_type)
d = self.lines[key]['time'][line_type]
if line := d['line']:
line.set_data(x, y)
line.set_color(self.line_props(key[0], d)['color'])
line.set_alpha(self.line_props(key[0], d)['alpha'])
line.set_zorder(self.line_props(key[0], d)['zorder'])
else:
line, = self.axes['time'][line_type].plot(x, y, **self.line_props(key[0], d))
self.update_line_attrs(['time'], [line_type], [key], stale=False, line=line)
@with_plot_context
def setup_figure(self):
gs = gridspec.GridSpec(2 + self.plot_cumulative + self.plot_timetrace, 1, figure=self.fig,
**self.gridspec_kw)
with self.plot_context:
self.setup_main_axes(gs)
self.setup_cumulative_axes(gs)
self.setup_time_axes(gs)
self.destroy_unused_axes()
self.update_figure()
@with_plot_context
def setup_main_axes(self, gs: gridspec.GridSpec):
if self.axes['main']['processed'] is None:
self.axes['main']['processed'] = self.fig.add_subplot(gs[:2], **self.subplot_kw)
......@@ -476,6 +494,7 @@ class PlotManager:
)
self.set_subplotspec('main', gs[:2])
@with_plot_context
def setup_cumulative_axes(self, gs: gridspec.GridSpec):
if self.plot_cumulative:
if self.axes['cumulative']['processed'] is None:
......@@ -503,6 +522,7 @@ class PlotManager:
)
self.set_subplotspec('cumulative', gs[2])
@with_plot_context
def setup_time_axes(self, gs: gridspec.GridSpec):
if self.plot_timetrace:
if self.axes['time']['processed'] is None:
......@@ -527,8 +547,8 @@ class PlotManager:
try:
self.axes[plot][line].remove()
self.axes[plot][line] = None
except AttributeError:
# Ax None
except (AttributeError, NotImplementedError):
# Ax None / Artist dead
continue
def destroy_unused_axes(self):
......@@ -552,6 +572,7 @@ class PlotManager:
# Line None
continue
@with_plot_context
def update_figure(self):
# Flush out all idle events, necessary for some reason in sequential mode
self.fig.canvas.flush_events()
......@@ -585,6 +606,7 @@ class PlotManager:
self.fig.canvas.draw_idle()
self.fig.canvas.flush_events()
@with_plot_context
def update_lines(self):
for key in self.shown:
for plot in self.plots_to_draw:
......@@ -607,6 +629,7 @@ class PlotManager:
for line in self.lines_to_draw:
self.axes[plot][line].set_subplotspec(gs)
@with_plot_context
def set_xlims(self):
# Frequency-axis plots
right = max((
......@@ -644,6 +667,7 @@ class PlotManager:
self.axes['time']['processed'].relim(visible_only=True)
self.axes['time']['processed'].autoscale(enable=True, axis='x', tight=True)
@with_plot_context
def set_ylims(self):
if not self.shown:
return
......@@ -675,6 +699,7 @@ class PlotManager:
# If bottom = top
self.axes[plot][line].set_ylim(bottom, top)
@with_plot_context
def set_xscales(self):
if (
# If daq returns complex data, the spectrum will have negative freqs
......@@ -762,38 +787,40 @@ def _ax_unit(amplitude: bool, density: bool, integrated: bool, cumulative_normal
dB: bool, unit: str) -> str:
if integrated and cumulative_normalized:
return ' (a.u.)'
if dB:
elif dB:
unit = 'dB'
power = '$^2$' if not amplitude and not dB else ''
hz_mul = 'Hz' if integrated and not density else ''
if density and not integrated:
return ' ({unit}{power}{hz_mul}{hz_div})'.format(
unit=unit,
power=power,
hz_mul=hz_mul,
hz_div=r'/$\sqrt{\mathrm{Hz}}$' if amplitude and density else r'/$\mathrm{Hz}$'
)
return ' ({unit}{power}{hz_mul})'.format(
unit=unit,
power=power,
hz_mul=hz_mul,
)
def _ax_label(amplitude: bool, integrated: bool, dB: bool, reference: _keyT) -> str:
if not dB:
return '{a}{b}S{c}(f{d}){e}'.format(
a=r'$\sqrt{{' if amplitude else '$',
b=r'\int_0^f\mathrm{{d}}f^\prime ' if integrated else '',
c='^2' if integrated and amplitude else '',
d=r'^\prime' if integrated else '',
e='}}$' if amplitude else '$'
)
return '{a}{b} relative to index {c}'.format(
a='integrated ' if integrated else '',
b='amplitude' if amplitude else 'power',
c=reference[0]
).capitalize()
elif not amplitude:
unit = unit + '$^2$'
labels = {
# (amplitude, density, integrated)
(False, False, False): f' ({unit})',
(False, False, True): f' ({unit}' + r'$\mathrm{Hz}$)',
(False, True, False): f' ({unit}' + r'$/\mathrm{Hz}$)',
(False, True, True): f' ({unit})',
(True, False, False): f' ({unit})',
(True, False, True): f' ({unit}' + r'$\sqrt{\mathrm{Hz}}$)',
(True, True, False): f' ({unit}' + r'$/\sqrt{\mathrm{Hz}}$)',
(True, True, True): f' ({unit})',
}
return labels[(amplitude, density, integrated)]
def _ax_label(amplitude: bool, integrated: bool, dB: bool, reference: Union[_keyT, None]) -> str:
labels = {
(False, False, False): '$S(f)$',
(False, True, False): r'$\sqrt{S(f)}$',
(False, False, True): r'$\mathrm{RMS}_S(f)^2$',
(False, True, True): r'$\mathrm{RMS}_S(f)$'
}
if dB:
labels |= {
(True, False, False): f'Pow. rel. to key {reference[0]}',
(True, True, False): f'Amp. rel. to key {reference[0]}',
(True, False, True): f'Int. pow. rel. to key {reference[0]}',
(True, True, True): f'Int. amp. rel. to key {reference[0]}'
}
return labels[(dB, amplitude, integrated)]
def _asinh_scale_maybe() -> Literal['asinh', 'linear']:
......
......@@ -3,6 +3,7 @@ import inspect
import os
import platform
import shelve
import sys
import warnings
from datetime import datetime
from pathlib import Path
......@@ -21,6 +22,7 @@ from matplotlib.figure import Figure
from matplotlib.legend import Legend
from qutil import io, misc
from qutil.functools import cached_property, chain, partial
from qutil.io import AsyncDatasaver
from qutil.itertools import count
from qutil.plotting import is_using_mpl_gui_backend
from qutil.plotting import live_view
......@@ -197,6 +199,10 @@ class Spectrometer:
threaded_acquisition : bool, default True
Acquire data in a separate thread. This keeps the plot window
responsive while acquisition is running.
blocking_acquisition : bool, default False
Block the interpreter while acquisition is running. This might
prevent concurrency errors when running a measurement script
that performs multiple acquisitions or plot actions.
prop_cycle : cycler.Cycler
A property cycler for styling the plotted lines.
play_sound : bool, default False
......@@ -374,7 +380,7 @@ class Spectrometer:
_to_expose = ('fig', 'ax', 'ax_raw', 'leg', 'plot_raw', 'plot_timetrace', 'plot_cumulative',
'plot_negative_frequencies', 'plot_absolute_frequencies', 'plot_amplitude',
'plot_density', 'plot_cumulative_normalized', 'plot_style', 'plot_dB_scale',
'threaded_acquisition', 'reference_spectrum', 'processed_unit')
'reference_spectrum', 'processed_unit')
# type checkers
fig: Figure
......@@ -410,7 +416,7 @@ class Spectrometer:
plot_update_mode: Optional[Literal['fast', 'always', 'never']] = None,
plot_dB_scale: bool = False, play_sound: bool = False,
audio_amplitude_normalization: Union[Literal["single_max"], float] = "single_max",
threaded_acquisition: bool = True,
threaded_acquisition: bool = True, blocking_acquisition: bool = False,
purge_raw_data: bool = False, prop_cycle=None, savepath: _pathT = None,
relative_paths: bool = True, compress: bool = True, raw_unit: str = 'V',
processed_unit: Optional[str] = None, figure_kw: Optional[Mapping] = None,
......@@ -420,6 +426,8 @@ class Spectrometer:
self._data: Dict[Tuple[int, str], Dict] = {}
self._savepath: Optional[Path] = None
self._acquiring = False
self._stop_event = Event()
self._datasaver = AsyncDatasaver('dill', compress)
self.daq = daq
self.procfn = chain(*procfn) if np.iterable(procfn) else chain(procfn or Id)
......@@ -427,13 +435,14 @@ class Spectrometer:
if savepath is None:
savepath = Path.home() / 'python_spectrometer' / datetime.now().strftime('%Y-%m-%d')
self.savepath = savepath
self.compress = compress
if plot_update_mode is not None:
warnings.warn('plot_update_mode is deprecated and has no effect', DeprecationWarning)
if purge_raw_data:
warnings.warn('Enabling purge raw data might break some plotting features!',
UserWarning)
self.purge_raw_data = purge_raw_data
self.threaded_acquisition = threaded_acquisition
self.blocking_acquisition = blocking_acquisition
if psd_estimator is None:
psd_estimator = {}
......@@ -455,10 +464,9 @@ class Spectrometer:
plot_cumulative, plot_negative_frequencies,
plot_absolute_frequencies, plot_amplitude,
plot_density, plot_cumulative_normalized,
plot_style, plot_dB_scale, threaded_acquisition,
prop_cycle, raw_unit, processed_unit,
uses_windowed_estimator, complex_data, figure_kw,
subplot_kw, gridspec_kw, legend_kw)
plot_style, plot_dB_scale, prop_cycle, raw_unit,
processed_unit, uses_windowed_estimator, complex_data,
figure_kw, subplot_kw, gridspec_kw, legend_kw)
self._audio_amplitude_normalization = audio_amplitude_normalization
self._play_sound = play_sound
......@@ -611,13 +619,9 @@ class Spectrometer:
keys = self.keys()
return ' - ' + '\n - '.join((str(key) for key in sorted(self.keys()) if key in keys))
@mock.patch.multiple('pickle', Unpickler=dill.Unpickler, Pickler=dill.Pickler)
def _savefn(self, file: _pathT, **kwargs):
file = io.check_path_length(self._resolve_path(file))
if self.compress:
np.savez_compressed(str(file), **_to_native_types(kwargs))
else:
np.savez(str(file), **_to_native_types(kwargs))
def _save(self, file: _pathT, **kwargs):
self._datasaver(io.check_path_length(self._resolve_path(file)),
**_to_native_types(kwargs))
@classmethod
def _make_kwargs_compatible(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
......@@ -688,7 +692,7 @@ class Spectrometer:
self._data[key]['S_processed'] = np.mean(self._data[key]['S_processed'], axis=0)[None]
self._data[key].update(measurement_metadata=metadata)
self._savefn(self._data[key]['filepath'], **self._data[key])
self._save(self._data[key]['filepath'], **self._data[key])
def _take_threaded(self, progress: bool, key: _keyT, n_avg: int, **settings):
"""Acquire data in a separate thread.
......@@ -730,7 +734,7 @@ class Spectrometer:
iterator = self.daq.acquire(n_avg=n_avg, **settings)
for i in progressbar(count(), disable=not progress, total=n_avg,
desc=f'Acquiring {n_avg} spectra with key {key}'):
if stop_flag.is_set():
if self._stop_event.is_set():
print('Acquisition interrupted.')
break
try:
......@@ -747,13 +751,13 @@ class Spectrometer:
self._acquiring = False
def on_close(event):
stop_flag.set()
self._stop_event.set()
self._acquiring = False
INTERACTIVE = is_using_mpl_gui_backend(self.fig)
self._stop_event.clear()
queue = Queue()
stop_flag = Event()
thread = Thread(target=acquire, daemon=True)
thread.start()
......@@ -830,6 +834,8 @@ class Spectrometer:
self._take_threaded(progress, key, **settings)
else:
self._take_sequential(progress, key, **settings)
if self.blocking_acquisition:
self.block_until_ready()
take.__doc__ = (take.__doc__.replace(8*' ', '')
+ '\n\nDAQ Parameters'
......@@ -1171,8 +1177,8 @@ class Spectrometer:
live_view_kw = {}
live_view_kw.setdefault('blocking_queue', True)
live_view_kw.setdefault('autoscale', 'c')
live_view_kw.setdefault('autoscale_interval_ms', None)
live_view_kw.setdefault('style', self.plot_style)
live_view_kw.setdefault(
'plot_legend',
'upper right' if np.issubdtype(self.daq.DTYPE, np.complexfloating) else False
......@@ -1182,14 +1188,15 @@ class Spectrometer:
plot_line=True, xlim=xlim, ylim=(0, (max_rows - 1) * T),
xlabel='$f$', ylabel='$t$', clabel='$S(f)$',
units={'x': 'Hz', 'y': 's', 'c': self.processed_unit + r'$/\sqrt{{Hz}}$'},
img_kw=dict(norm=colors.LogNorm(vmin=0.1, vmax=10), cmap='Blues')
img_kw=dict(norm=colors.LogNorm(vmin=0.1, vmax=10))
)
freq_kw = {'xscale': freq_xscale} | live_view_kw | fixed_kw
if any(freq_kw[key] != val for key, val in live_view_kw.items()):
freq_kw = {'autoscale': 'c', 'xscale': freq_xscale, 'img_kw': {'cmap': 'Blues'}}
freq_kw = _merge_recursive(freq_kw, _merge_recursive(live_view_kw, fixed_kw))
if not _dict_is_subset(live_view_kw, freq_kw):
warnings.warn('Overrode some keyword arguments for FrequencyLiveView', UserWarning)
# The view(s) get data from these queues and subsequently put them into their own
stop_event = Event()
self._stop_event.clear()
get_queues = [LifoQueue(maxsize=int(live_view_kw['blocking_queue']))]
views = [get_live_view(FrequencyLiveView, put_frequency_data, get_queue=get_queues[0],
**freq_kw)]
......@@ -1198,8 +1205,9 @@ class Spectrometer:
fixed_kw = dict(xlim=(0, T), xlabel='$t$', ylabel='Amplitude',
n_lines=2 if np.issubdtype(self.daq.DTYPE, np.complexfloating) else 1,
units={'x': 's', 'y': self.processed_unit})
time_kw = live_view_kw | fixed_kw
if any(time_kw[key] != val for key, val in live_view_kw.items()):
time_kw = {'autoscale': 'y'}
time_kw = _merge_recursive(time_kw, _merge_recursive(live_view_kw, fixed_kw))
if not _dict_is_subset(live_view_kw, time_kw):
warnings.warn('Overrode some keyword arguments for TimeLiveView', UserWarning)
get_queues.append(LifoQueue(maxsize=int(live_view_kw['blocking_queue'])))
......@@ -1210,23 +1218,24 @@ class Spectrometer:
# thread (and stop the other view)
watcher_threads = [
Thread(target=monitor_event,
args=(views[0].stop_event, stop_event, views[1].stop_event,
args=(views[0].stop_event, self._stop_event, views[1].stop_event,
# Need to add the proxy's close_event for headless tests
views[1].close_event if in_process else None),
daemon=True),
Thread(target=monitor_event,
args=(views[1].stop_event, stop_event, views[0].stop_event,
args=(views[1].stop_event, self._stop_event, views[0].stop_event,
views[0].close_event if in_process else None),
daemon=True),
]
else:
watcher_threads = [
Thread(target=monitor_event, args=(views[0].stop_event, stop_event), daemon=True)
Thread(target=monitor_event, args=(views[0].stop_event, self._stop_event),
daemon=True)
]
# The feeder thread actually performs the data acquisition and distributes it into the
# get_queues until stop_event is set.
feeder_thread = Thread(target=acquire_and_feed, args=[stop_event] + get_queues,
feeder_thread = Thread(target=acquire_and_feed, args=[self._stop_event] + get_queues,
daemon=True)
# Finally, start all the threads
......@@ -1261,6 +1270,10 @@ class Spectrometer:
while self.acquiring and not exceeded:
self.fig.canvas.start_event_loop(self._plot_manager.TIMER_INTERVAL * 1e-3)
def abort_acquisition(self):
"""Abort the current acquisition."""
self._stop_event.set()
def reprocess_data(self,
*comment_or_index: _keyT,
save: Literal[False, True, 'overwrite'] = False,
......@@ -1298,7 +1311,7 @@ class Spectrometer:
data['filepath'] = io.query_overwrite(data['filepath'])
else:
data['filepath'] = self._get_new_file(comment=data['comment'])
self._savefn(data['filepath'], **data)
self._datasaver(data['filepath'], **data)
self._data[key] = data
self._plot_manager.update_line_attrs(self._plot_manager.plots_to_draw,
......@@ -1403,10 +1416,8 @@ class Spectrometer:
data['filepath'] = newfile if relative_paths else savepath / newfile
newfile = io.query_overwrite(io.check_path_length(savepath / newfile))
if compress:
np.savez_compressed(savepath / newfile, **_to_native_types(data))
else:
np.savez(savepath / newfile, **_to_native_types(data))
with AsyncDatasaver('dill', compress) as datasaver:
datasaver.save_sync(savepath / newfile, **_to_native_types(data))
if newfile == oldfile:
# Already 'deleted' (overwrote) the old file
......@@ -1458,7 +1469,7 @@ class Spectrometer:
'plot_raw', 'plot_timetrace', 'plot_cumulative',
'plot_negative_frequencies', 'plot_absolute_frequencies',
'plot_amplitude', 'plot_density', 'plot_cumulative_normalized',
'plot_style', 'plot_update_mode', 'plot_dB_scale', 'compress']
'plot_style', 'plot_update_mode', 'plot_dB_scale']
plot_manager_attrs = ['reference_spectrum', 'prop_cycle', 'raw_unit', 'processed_unit']
with shelve.open(str(file), protocol=protocol) as db:
# Constructor args
......@@ -1577,6 +1588,8 @@ class Spectrometer:
key = (self._index, data['comment'])
self._data[key] = data
# Make sure the figure is live
self._plot_manager.update_figure()
self._plot_manager.add_new_line_entry(key)
if show:
self.show(key, color=color)
......@@ -1626,7 +1639,14 @@ def _load_spectrum(file: _pathT) -> Dict[str, Any]:
def __exit__(self, exc_type, exc_val, exc_tb):
delattr(io, 'JanewayWindowsPath')
with np.load(file, allow_pickle=True) as fp, monkey_patched_io():
# Patch modules for data saved before move to separate package
renamed_modules = {'qutil.measurement.spectrometer.daq.settings': daq_settings}
with (
mock.patch.dict(sys.modules, renamed_modules),
np.load(file, allow_pickle=True) as fp,
monkey_patched_io()
):
data = {}
for key, val in fp.items():
try:
......@@ -1698,6 +1718,78 @@ def _from_native_types(data: Dict[str, Any]) -> Dict[str, Any]:
return data
def _merge_recursive(orig: dict, upd: dict, inplace: bool = False) -> dict:
"""Recursively update the original dictionary 'orig' with the
values from 'upd'.
If both dictionaries have a value for the same key and those values
are also dictionaries, the function is called recursively on those
nested dictionaries.
Parameters
----------
orig :
The dictionary that is updated.
upd :
The dictionary whose values are used for updating.
inplace :
Merge or update.
Returns
-------
dict :
The updated original dictionary.
"""
if not inplace:
orig = copy.copy(orig)
for key, upd_value in upd.items():
orig_value = orig.get(key)
# If both the existing value and the update value are dictionaries,
# merge them recursively.
if isinstance(orig_value, dict) and isinstance(upd_value, dict):
# Since `orig[key]` is already part of the copied version
# if not inplace, we can update it in place.
_merge_recursive(orig_value, upd_value, inplace=True)
else:
# Otherwise, replace or add the value from `upd` into `orig`.
orig[key] = upd_value
return orig
def _dict_is_subset(source: dict, target: dict) -> bool:
"""
Checks recursively whether the nested dict *source* is a subset of
*target*.
Parameters
----------
source :
The dictionary whose keys must be present.
target :
The dictionary to check for the required keys.
"""
for key, source_value in source.items():
# Check if the key exists in the target dictionary.
if key not in target:
return False
target_value = target[key]
# If both values are dictionaries, check recursively.
if isinstance(source_value, dict):
if not isinstance(target_value, dict):
return False
if not _dict_is_subset(source_value, target_value):
return False
else:
# For non-dict values, check equality.
if source_value != target_value:
return False
return True
class ReadonlyError(Exception):
"""Indicates a :class:`Spectrometer` object is read-only."""
pass
......@@ -466,16 +466,26 @@ class DAQSettings(dict):
fs = self._domain_df.next_smallest(fs / self['nperseg']) * self['nperseg']
if not self._isclose(fs, fs_prev):
self['df'] = self._domain_df.next_closest(fs_prev / self['nperseg'])
self._make_compatible_nperseg(ceil(fs_prev / self['df']))
self._make_compatible_nperseg(ceil(self._domain_df.round(fs_prev / self['df'])))
return self._to_allowed_fs(fs_prev)
# Constraints on fs itself
if not isinf(df) and not (fs / df) % 1:
if not isinf(df) and not self._domain_fs.round(fs / df) % 1:
# fs might be due to ceil-ing when inferring nperseg. Use next_closest
fs = self._domain_fs.next_closest(fs)
else:
fs = self._domain_fs.next_largest(fs)
if not self._isclose(fs, fs_prev):
self._make_compatible_fs(fs)
return fs
# Finally, as a last resort test if the parameters match. If not, try to adjust nperseg
if not isinf(df) and 'nperseg' in self:
fs = df / self['nperseg']
if not self._isclose(fs, fs_prev):
self['fs'] = self._domain_fs.next_closest(fs_prev)
self._make_compatible_nperseg(
self._domain_nperseg.next_closest(self['fs'] / df)
)
return self['fs']
return fs
......@@ -490,7 +500,7 @@ class DAQSettings(dict):
if not self._isclose(df, df_prev):
self['fs'] = self._domain_fs.next_closest(df_prev * self['nperseg'])
# Use df instead of df_prev here because we preferentially adjust df over fs or nperseg
self._make_compatible_nperseg(ceil(self['fs'] / df))
self._make_compatible_nperseg(ceil(self._domain_fs.round(self['fs'] / df)))
return self._to_allowed_df(df)
if not isinf(fs := self.get('fs', self.get('f_max', inf) * 2)):
# Constraints on nperseg constrain df
......@@ -503,6 +513,16 @@ class DAQSettings(dict):
df = self._domain_df.next_smallest(df)
if not self._isclose(df, df_prev):
self._make_compatible_df(df)
return df
# Finally, as a last resort test if the parameters match. If not, try to adjust nperseg
if not isinf(fs) and 'nperseg' in self:
df = fs / self['nperseg']
if not self._isclose(df, df_prev):
self['df'] = self._domain_df.next_closest(df_prev)
self._make_compatible_nperseg(
self._domain_nperseg.next_closest(fs / self['df'])
)
return self['df']
return df
......@@ -522,14 +542,14 @@ class DAQSettings(dict):
df = self.get('df', self.get('f_min', inf))
if not isinf(df):
# Constraints on fs constrain nperseg through df/f_min
nperseg = ceil(self._domain_fs.next_largest(df * nperseg) / df)
nperseg = ceil(self._domain_df.round(self._domain_fs.next_largest(df * nperseg) / df))
if nperseg != nperseg_prev:
self['fs'] = self._domain_fs.next_largest(fs if not isinf(fs) else df * nperseg_prev)
self._make_compatible_df(self['fs'] / nperseg_prev)
return self._to_allowed_nperseg(nperseg_prev)
if not isinf(fs):
# Constraints on df constrain nperseg through fs/f_max
nperseg = ceil(fs / self._domain_df.next_smallest(fs / nperseg))
nperseg = ceil(self._domain_fs.round(fs / self._domain_df.next_smallest(fs / nperseg)))
if nperseg != nperseg_prev:
self['df'] = self._domain_df.next_closest(df if not isinf(df) else fs * nperseg_prev)
self._make_compatible_fs(self['df'] * nperseg_prev)
......@@ -538,6 +558,14 @@ class DAQSettings(dict):
nperseg = self._domain_nperseg.next_largest(nperseg)
if nperseg != nperseg_prev:
self._make_compatible_nperseg(nperseg)
return nperseg
# Finally, as a last resort test if the parameters match. If not, try to adjust df
if not isinf(df) and not isinf(fs):
nperseg = ceil(self._domain_fs.round(fs / df))
if nperseg != nperseg_prev:
self['nperseg'] = self._domain_nperseg.next_closest(nperseg_prev)
self._make_compatible_df(self._domain_df.next_closest(self['fs'] / self['nperseg']))
return self['nperseg']
return nperseg
......@@ -630,7 +658,7 @@ class DAQSettings(dict):
def _infer_nperseg(self, default: bool = False) -> int | None:
# user-set fs or df take precedence over noverlap
if 'fs' in self and 'df' in self:
return ceil(self['fs'] / self['df'])
return ceil(self._domain_fs.round(self['fs'] / self['df']))
if 'n_pts' in self:
if 'noverlap' in self and 'n_seg' in self:
return int((self['n_pts'] + (self['n_seg'] - 1) * self['noverlap'])
......@@ -648,7 +676,7 @@ class DAQSettings(dict):
if not isinf(df) and not isinf(fs):
# In principle this should be self._domain_nperseg.next_largest(), but that recurses
# infinitely. So we do what we can.
return min(ceil(fs / df), self._upper_bound_nperseg())
return min(ceil(self._domain_fs.round(fs / df)), self._upper_bound_nperseg())
return None
def _infer_noverlap(self, default: bool = False) -> int | None:
......@@ -775,7 +803,7 @@ class DAQSettings(dict):
return self['nperseg']
if (nperseg := self._infer_nperseg()) is not None:
return nperseg
return self.setdefault('nperseg', ceil(self.fs / self.df))
return self.setdefault('nperseg', ceil(self._domain_fs.round(self.fs / self.df)))
@interdependent_daq_property
def noverlap(self) -> int:
......
......@@ -58,7 +58,7 @@ import sys
import time
import warnings
from abc import ABC
from typing import Any, Dict, Mapping, Optional, Type, Union
from typing import Any, Dict, Mapping, Optional, Type
import numpy as np
from packaging import version
......
......@@ -39,10 +39,14 @@ def test_daq_settings_warnings():
s = DAQSettings(nperseg=1000, df=0.999)
s.fs = 1000
with pytest.warns(UserWarning, match='Need to change fs from 1001 to 1000.'):
with pytest.warns(UserWarning, match='Need to change df from 1 to 1.001'):
s = DAQSettings(fs=1001, df=1)
s.nperseg = 1000
with pytest.warns(UserWarning, match='Need to change nperseg from 1000 to 1001'):
s = DAQSettings(fs=1001, nperseg=1000)
s.df = 1
def test_daq_settings_exceptions():
......@@ -134,6 +138,10 @@ def test_mfli_daq_settings(mock_zi_daq):
t = mock_zi_daq.DAQSettings(**s.to_consistent_dict())
t.to_consistent_dict()
s = mock_zi_daq.DAQSettings(f_min=1e1, f_max=1e5)
t = mock_zi_daq.DAQSettings(s.to_consistent_dict())
t.to_consistent_dict()
def test_reproducibility():
"""Test if there are no exceptions or infinite recursions for a
......
......@@ -2,6 +2,7 @@ import os
import pathlib
import random
import string
import time
from tempfile import mkdtemp
import pytest
......@@ -79,7 +80,8 @@ def spectrometer(monkeypatch, relative_paths: bool, threaded_acquisition: bool):
@pytest.fixture
def serialized(spectrometer: Spectrometer):
def serialized(done_saving: Spectrometer):
spectrometer = done_saving
stem = ''.join(random.choices(string.ascii_letters, k=10))
try:
......@@ -96,8 +98,18 @@ def serialized(spectrometer: Spectrometer):
remove_file_if_exists(spectrometer.savepath / f'{stem}{ext}')
def test_saving(spectrometer: Spectrometer, relative_paths: bool):
@pytest.fixture
def done_saving(spectrometer: Spectrometer):
while spectrometer._datasaver.jobs:
time.sleep(0.1)
return spectrometer
def test_saving(done_saving: Spectrometer, relative_paths: bool):
spectrometer = done_saving
assert spectrometer.savepath.exists()
for file in spectrometer.files:
if relative_paths:
assert os.path.exists(spectrometer.savepath / file)
......@@ -105,16 +117,14 @@ def test_saving(spectrometer: Spectrometer, relative_paths: bool):
assert os.path.exists(file)
def test_serialization(spectrometer: Spectrometer):
spectrometer.serialize_to_disk('blub')
def test_serialization(spectrometer: Spectrometer, serialized: pathlib.Path):
exts = ['_files.txt']
if (spectrometer.savepath / 'blub').is_file():
assert os.path.exists(spectrometer.savepath / 'blub')
if serialized.is_file():
assert os.path.exists(serialized)
else:
exts.extend(['.bak', '.dat', '.dir'])
for ext in exts:
assert os.path.exists(spectrometer.savepath / f'blub{ext}')
assert os.path.exists(serialized.with_name(serialized.name + ext))
def test_deserialization(serialized: pathlib.Path):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment