import numpy as np
import numba
import strax
export, __all__ = strax.exporter()
[docs]
@export
def split_peaks(
peaks,
hits,
records,
rlinks,
to_pe,
algorithm="local_minimum",
data_type="peaks",
n_top_channels=0,
store_data_top=False,
store_data_start=False,
**kwargs,
):
"""Return peaks split according to algorithm, with waveforms summed and widths computed.
Note:
Can also be used for hitlets splitting with local_minimum
splitter. Just put hitlets instead of peaks.
:param peaks: Original peaks. Sum waveform must have been built
and properties must have been computed (if you use them)
:param hits: Hits found in records. (or None in case of hitlets
splitting.)
:param records: Records from which peaks were built
:param rlinks: strax.record_links for given records
(or None in case of hitlets splitting.)
:param to_pe: ADC to PE conversion factor array (of n_channels)
:param algorithm: 'local_minimum' or 'natural_breaks'.
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use
sum_wavefrom or get_hitlets_data to compute the waveform of
the new split peaks/hitlets.
:param n_top_channels: Number of top array channels.
:param result_dtype: dtype of the result.
:param store_data_top: Boolean which indicates whether to store the top array
waveform in the peak.
:param store_data_start: Boolean which indicates whether to store the first samples of the
waveform in the peak.
Any other options are passed to the algorithm.
"""
splitter = dict(local_minimum=LocalMinimumSplitter, natural_breaks=NaturalBreaksSplitter)[
algorithm
]()
data_type_is_not_supported = data_type not in ("hitlets", "peaks")
if data_type_is_not_supported:
raise TypeError(f'Data_type "{data_type}" is not supported.')
return splitter(
peaks,
hits,
records,
rlinks,
to_pe,
data_type,
n_top_channels=n_top_channels,
store_data_top=store_data_top,
store_data_start=store_data_start,
**kwargs,
)
NO_MORE_SPLITS = -9999999
class PeakSplitter:
"""Split peaks into more peaks based on arbitrary algorithm.
:param peaks: Original peaks. Sum waveform must have been built and properties must have been
computed (if you use them).
:param records: Records from which peaks were built.
:param rlinks: strax.record_links for given records.
:param to_pe: ADC to PE conversion factor array (of n_channels).
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use sum_waveform or
get_hitlets_data to compute the waveform of the new split peaks/hitlets.
:param do_iterations: maximum number of times peaks are recursively split.
:param min_area: Minimum area to do split. Smaller peaks are not split.
:param n_top_channels: Number of top array channels. The function find_split_points(),
implemented in each subclass defines the algorithm, which takes in a peak's waveform and
returns the index to split the peak at, if a split point is found. Otherwise NO_MORE_SPLITS
is returned and the peak is left as is.
:param store_data_top: Boolean which indicates whether to store the top array waveform in the
peak.
:param store_data_start: Boolean which indicates whether to store the first samples of the
waveform in the peak.
"""
find_split_args_defaults: tuple
def __call__(
self,
peaks,
hits,
records,
rlinks,
to_pe,
data_type,
do_iterations=1,
min_area=0,
n_top_channels=0,
store_data_top=False,
store_data_start=False,
**kwargs,
):
if not len(records) or not len(peaks) or not do_iterations:
return peaks
# Build the *args tuple for self.find_split_points from kwargs
# since numba doesn't support **kwargs
args_options = []
for i, (k, value) in enumerate(self.find_split_args_defaults):
if k in kwargs:
value = kwargs[k]
if k == "threshold":
# The 'threshold' option is a user-specified function
value = value(peaks)
args_options.append(value)
args_options = tuple(args_options)
# Check for spurious options
argnames = [k for k, _ in self.find_split_args_defaults]
for k in kwargs:
if k not in argnames:
raise TypeError(f"Unknown argument {k} for {self.__class__}")
is_split = np.zeros(len(peaks), dtype=bool)
new_peaks = self._split_peaks(
# Numba doesn't like self as argument, but it's ok with functions...
split_finder=self.find_split_points,
peaks=peaks,
is_split=is_split,
orig_dt=records[0]["dt"],
min_area=min_area,
args_options=tuple(args_options),
result_dtype=peaks.dtype,
)
if is_split.sum() != 0:
# Found new peaks: compute basic properties
if data_type == "peaks":
strax.sum_waveform(
new_peaks,
hits,
records,
rlinks,
to_pe,
n_top_channels=n_top_channels,
store_data_top=store_data_top,
store_data_start=store_data_start,
)
strax.compute_properties(new_peaks, n_top_channels=n_top_channels)
elif data_type == "hitlets":
# Add record fields here
new_peaks = strax.sort_by_time(
new_peaks
) # Hitlets are not necessarily sorted after splitting
new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
# ... and recurse (if needed)
new_peaks = self(
new_peaks,
hits,
records,
rlinks,
to_pe,
data_type,
do_iterations=do_iterations - 1,
min_area=min_area,
n_top_channels=n_top_channels,
store_data_top=store_data_top,
store_data_start=store_data_start,
**kwargs,
)
if np.any(new_peaks["length"] == 0):
raise ValueError("Want to add a new zero-length peak after splitting!")
peaks = strax.sort_by_time(np.concatenate([peaks[~is_split], new_peaks]))
return peaks
# this function can not be cached due to some unknown reasons
# maybe because the split_finder is a function and numba does not like it
@staticmethod
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.njit(nogil=True)
def _split_peaks(
split_finder,
peaks,
orig_dt,
is_split,
min_area,
args_options,
_result_buffer=None,
result_dtype=None,
):
"""Loop over peaks, pass waveforms to algorithm, construct new peaks if and where a split
occurs."""
new_peaks = _result_buffer
offset = 0
for p_i, p in enumerate(peaks):
if p["area"] < min_area:
continue
prev_split_i = 0
w = p["data"][: p["length"]]
for split_i, bonus_output in split_finder(w, p["dt"], p_i, *args_options):
if split_i == NO_MORE_SPLITS:
p["max_goodness_of_split"] = bonus_output
# although the iteration will end anyway afterwards:
continue
is_split[p_i] = True
r = new_peaks[offset]
r["time"] = p["time"] + prev_split_i * p["dt"]
r["channel"] = p["channel"]
# Set the dt to the original (lowest) dt first;
# this may change when the sum waveform of the new peak
# is computed
r["dt"] = orig_dt
r["length"] = (split_i - prev_split_i) * p["dt"] / orig_dt
# Too lazy to compute these
r["max_gap"] = -1
r["max_diff"] = -1
r["min_diff"] = -1
r["first_channel"] = -1
r["last_channel"] = -1
if r["length"] <= 0:
print(p["data"])
print(prev_split_i, split_i)
raise ValueError("Attempt to create invalid peak!")
offset += 1
if offset == len(new_peaks):
yield offset
offset = 0
prev_split_i = split_i
yield offset
@staticmethod
def find_split_points(w, dt, peak_i, *args_options):
"""This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter bare
PeakSplitter class is not implemented."""
raise NotImplementedError
class LocalMinimumSplitter(PeakSplitter):
"""Split peaks at significant local minima.
On either side of a split point, local maxima are required to be
- larger than minimum + min_height, AND
- larger than minimum * min_ratio
This is related to topographical prominence for mountains.
NB: Min_height is in pe/ns, NOT pe/bin!
"""
find_split_args_defaults = (("min_height", 0), ("min_ratio", 0))
@staticmethod
@numba.njit(nogil=True)
def find_split_points(w, dt, peak_i, min_height, min_ratio):
"""Yields indices of prominent local minima in w If there was at least one index, yields
len(w)-1 at the end."""
found_one = False
last_max = -99999999999999.9
min_since_max = 99999999999999.9
min_since_max_i = 0
for i, x in enumerate(w):
if x < min_since_max:
# New minimum since last max
min_since_max = x
min_since_max_i = i
if min(last_max, x) > max(min_since_max + min_height, min_since_max * min_ratio):
# Significant local minimum: tell caller,
# reset both max and min finder
yield min_since_max_i, 0.0
found_one = True
last_max = x
min_since_max = 99999999999999.9
min_since_max_i = i
if x > last_max:
# New max, reset minimum finder state
# Notice this is AFTER the split check,
# to accomodate very fast rising second peaks
last_max = x
min_since_max = 99999999999999.9
min_since_max_i = i
if found_one:
yield len(w), 0.0
yield NO_MORE_SPLITS, 0.0
class NaturalBreaksSplitter(PeakSplitter):
"""Split peaks according to (variations of) the natural breaks algorithm, i.e. such that the sum
squared difference from the mean is minimized.
Options:
- threshold: threshold to accept a split in the goodness of split value:
1 - (f(left) + f(right))/f(unsplit)
- normalize: if True, f is the variance. Otherwise, it is the
sum squared difference from the mean (i.e. unnormalized variance)
- split_low: if True, multiply the goodness of split value by one minus
the ratio between the waveform at the split point and the maximum in
the waveform. This prevent splits at high density points.
- filter_wing_width: if > 0, do a moving average filter (without shift)
on the waveform before the split_low computation.
The window will include the sample itself, plus filter_wing_width (or as
close as we can get to it given the peaks sampling) on either side.
"""
find_split_args_defaults = (
("threshold", None), # will be a numpy array of len(peaks)
("normalize", False),
("split_low", False),
("filter_wing_width", 0),
)
@staticmethod
@numba.njit(nogil=True)
def find_split_points(w, dt, peak_i, threshold, normalize, split_low, filter_wing_width):
gofs = natural_breaks_gof(
w, dt, normalize=normalize, split_low=split_low, filter_wing_width=filter_wing_width
)
max_i = np.argmax(gofs)
if gofs[max_i] > threshold[peak_i]:
yield max_i, 0.0
yield len(w) - 1, 0.0
yield NO_MORE_SPLITS, gofs[max_i]
[docs]
@export
@numba.njit(nogil=True, cache=True)
def natural_breaks_gof(w, dt, normalize=False, split_low=False, filter_wing_width=0):
"""Return natural breaks goodness of split/fit for the waveform w a sharp peak gives ~0, two
widely separate peaks ~1."""
left = sum_squared_deviations(w, normalize=normalize)
right = sum_squared_deviations(w[::-1], normalize=normalize)[::-1]
gof = 1 - (left + right) / left[-1]
if split_low:
# Adjust to prevent splits at high density points
filter_n = filter_wing_width // dt - 1
if filter_n > 0:
filtered_w = symmetric_moving_average(w, filter_n)
else:
filtered_w = w
gof *= 1 - filtered_w / filtered_w.max()
return gof
[docs]
@export
@numba.njit(nogil=True, cache=True)
def symmetric_moving_average(a, wing_width):
"""Return the moving average of a, over windows of length [2 * wing_width + 1] centered on each
sample.
(i.e. the window covers each sample itself, plus a 'wing' of width wing_width on either side)
"""
if wing_width == 0:
return a
n = len(a)
out = np.empty(n, dtype=a.dtype)
asum = a[:wing_width].sum()
count = wing_width
for i in range(len(a)):
# Index of the sample that just disappeared
# from the window
just_out = i - wing_width - 1
if just_out > 0:
count -= 1
asum -= a[just_out]
# Index of the sample that just appeared
# in the window
just_in = i + wing_width
if just_in < n:
count += 1
asum += a[just_in]
out[i] = asum / count
return out
@numba.njit(nogil=True, cache=True)
def sum_squared_deviations(waveform, normalize=False):
"""Return left-to-right result of an online sum-intra-class variance computation on the
waveform.
:param normalize: If True, divide by the total area, i.e. produce ordinary variance.
"""
mean = sum_weights = s = 0
result = np.zeros(len(waveform))
for i, w in enumerate(waveform):
# Negative weights can lead to odd results, so clip waveform
w = max(0, w)
sum_weights += w
if sum_weights == 0:
continue
mean_old = mean
mean = mean_old + (w / sum_weights) * (i - mean_old)
s += w * (i - mean_old) * (i - mean)
result[i] = s
if normalize:
result[i] /= sum_weights
return result