Source code for strax.processing.peak_merging

import math
import numba
import numpy as np

import strax
from strax.processing.general import _fully_contained_in, _fully_contained_in_sanity

# `export` is applied as a decorator to the public functions of this module.
# NOTE(review): presumably strax.exporter() returns that decorator together with
# the list it appends each decorated name to (bound here as __all__) -- confirm
# against the strax.exporter implementation.
export, __all__ = strax.exporter()


@export
@numba.njit(cache=True, nogil=True)
def gcd_of_array(values):
    """Return the greatest common divisor of every element in ``values``.

    :param values: Non-empty array of integers; the first element seeds the
        running GCD, so an empty array is not supported.
    :return: The GCD obtained by folding ``math.gcd`` over all elements.

    """
    common = values[0]
    # Fold the remaining elements into the running divisor one at a time.
    for value in values[1:]:
        common = math.gcd(common, value)
    return common
@export
def merge_peaks(
    peaks,
    start_merge_at,
    end_merge_at,
    merged=None,
    max_buffer=int(1e5),
):
    """Merge specified peaks with their neighbors, return merged peaks.

    Thin Python wrapper around the compiled ``_merge_peaks``; the only extra
    work done here is filling the ``endtime`` field when the dtype has one.

    :param peaks: Record array of strax peak dtype.
    :param start_merge_at: Indices to start merge at
    :param end_merge_at: EXCLUSIVE indices to end merge at
    :param merged: Optional mask with one entry per peak, forwarded unchanged
        to ``_merge_peaks``. ``None`` (the default) means no mask is applied.
    :param max_buffer: Maximum number of samples in the sum_waveforms and other
        waveforms of the resulting peaks (after merging). Peaks must be
        constructed based on the properties of constituent peaks, it being too
        time-consuming to revert to records/hits.
    :return: Array of merged peaks with the same dtype as ``peaks``, one entry
        per (start, end) pair.

    """
    merged_peaks, merged_endtimes = _merge_peaks(
        peaks,
        start_merge_at,
        end_merge_at,
        merged=merged,
        max_buffer=max_buffer,
    )

    # Nothing more to do for dtypes that do not carry an explicit endtime.
    if "endtime" not in peaks.dtype.names:
        return merged_peaks

    # The endtime field would otherwise stay at zero (the merged peaks start
    # out as a zeroed buffer). It cannot be derived as time + length * dt
    # because of the downsampling, so _merge_peaks collects it separately.
    merged_peaks["endtime"] = merged_endtimes
    return merged_peaks
@numba.njit(cache=True, nogil=True)
def _merge_peaks(
    peaks,
    start_merge_at,
    end_merge_at,
    merged=None,
    max_buffer=int(1e5),
):
    """Merge specified peaks with their neighbors, return merged peaks.

    Compiled core of ``merge_peaks``. Every (start, end) index pair produces
    one new peak whose waveform is rebuilt from the waveforms of the
    constituent peaks rather than from records/hits.

    :param peaks: Record array of strax peak dtype. Consecutive peaks must not
        overlap in time.
    :param start_merge_at: Indices to start merge at
    :param end_merge_at: EXCLUSIVE indices to end merge at
    :param merged: Optional mask with one entry per peak. When given, only the
        flagged peaks inside each [start, end) range are combined.
    :param max_buffer: Maximum number of samples in the sum_waveforms and other
        waveforms of the resulting peaks (after merging). Peaks must be
        constructed based on the properties of constituent peaks, it being too
        time-consuming to revert to records/hits.
    :return: Tuple ``(new_peaks, endtime)``: the merged peaks (same dtype as
        ``peaks``) and an int64 array holding the endtime of the last
        constituent peak of each merged peak.
    :raises ValueError: If consecutive peaks overlap, or if ``merged`` selects
        no peak at all within one of the merge ranges.

    """
    assert len(start_merge_at) == len(end_merge_at)
    if merged is not None and len(start_merge_at):
        assert len(merged) == len(peaks)
        assert len(merged) >= max(end_merge_at)

    # The disjointness test compares each peak with its predecessor. With fewer
    # than two peaks there is nothing to compare, and np.min would raise on the
    # resulting zero-size array, so skip the test in that case.
    if len(peaks) > 1 and np.min(peaks["time"][1:] - strax.endtime(peaks)[:-1]) < 0:
        raise ValueError("Peaks not disjoint! You have to rewrite this function to handle this.")

    new_peaks = np.zeros(len(start_merge_at), dtype=peaks.dtype)
    endtime = np.zeros(len(start_merge_at), dtype=np.int64)

    # Scratch waveforms at the common (finest) sampling; reused for every
    # merged peak and re-zeroed per iteration below.
    buffer = np.zeros(max_buffer, dtype=np.float32)
    buffer_top = np.zeros(max_buffer, dtype=np.float32)

    for new_i, new_p in enumerate(new_peaks):
        new_p["min_diff"] = 2147483647  # inf of int32

        sl = slice(start_merge_at[new_i], end_merge_at[new_i])
        old_peaks = peaks[sl]
        # if merged is not None, we have to only take the merged peaks
        if merged is not None:
            if not np.any(merged[sl]):
                raise ValueError("Trying to merge zero peaks!")
            old_peaks = old_peaks[merged[sl]]
        common_dt = gcd_of_array(old_peaks["dt"])
        first_peak, last_peak = old_peaks[0], old_peaks[-1]
        new_p["channel"] = first_peak["channel"]

        new_p["time"] = first_peak["time"]
        new_p["dt"] = common_dt
        # The new endtime must be at or before the last peak endtime
        # to avoid possibly overlapping peaks (hence the floor division)
        new_p["length"] = (strax.endtime(last_peak) - new_p["time"]) // common_dt

        # re-zero relevant part of buffers (overkill? not sure if
        # this saves much time)
        bl = last_peak["time"] - first_peak["time"]
        bl += last_peak["length"] * old_peaks["dt"].max()
        bl = min(int(bl / common_dt), max_buffer)
        buffer[:bl] = 0
        buffer_top[:bl] = 0

        # Maximum sample of each constituent peak, used further down to pick
        # the tight_coincidence of the highest-amplitude peak.
        max_data = []
        for p in old_peaks:
            # Upsample the sum and top/bottom array waveforms into their buffers
            upsample = p["dt"] // common_dt
            n_after = p["length"] * upsample
            i0 = (p["time"] - new_p["time"]) // common_dt
            # Each sample is repeated `upsample` times and divided by the same
            # factor, so the summed content of the waveform is preserved.
            buffer[i0 : i0 + n_after] = np.repeat(p["data"][: p["length"]], upsample) / upsample
            buffer_top[i0 : i0 + n_after] = (
                np.repeat(p["data_top"][: p["length"]], upsample) / upsample
            )

            # Handle the other peak attributes
            new_p["area"] += p["area"]
            new_p["area_per_channel"] += p["area_per_channel"]
            new_p["n_hits"] += p["n_hits"]
            new_p["saturated_channel"][p["saturated_channel"] == 1] = 1
            max_data.append(p["data"][: p["length"]].max())

            # Propagate min/max diff for sub-peaklets, for max diff this
            # is just an approximation since peaklets can be farther apart.
            # The value of the individual peaklet is more informative.
            # For min diff the value is correct.
            new_p["max_diff"] = max(new_p["max_diff"], p["max_diff"])
            new_p["min_diff"] = min(new_p["min_diff"], p["min_diff"])

        max_data = np.array(max_data)

        # Downsample the buffers into
        # new_p['data'], new_p['data_top'], and new_p['data_start']
        # NOTE(review): the two positional True flags presumably switch on the
        # data_top and data_start outputs -- confirm against the signature of
        # strax.store_downsampled_waveform.
        strax.store_downsampled_waveform(
            new_p,
            buffer,
            True,
            True,
            buffer_top,
        )

        new_p["n_saturated_channels"] = new_p["saturated_channel"].sum()

        # too lazy to compute these
        new_p["max_gap"] = -1
        new_p["max_goodness_of_split"] = np.nan

        # since the peaks are sorted, we can just take the first and last channel
        new_p["first_channel"] = old_peaks[0]["first_channel"]
        new_p["last_channel"] = old_peaks[-1]["last_channel"]

        # Use tight_coincidence of the peak with the highest amplitude
        new_p["tight_coincidence"] = old_peaks["tight_coincidence"][np.argmax(max_data)]

        # collect the endtime
        endtime[new_i] = strax.endtime(last_peak)
    return new_peaks, endtime
@export
def replace_merged(orig, merge):
    """Return sorted array of 'merge' and members of 'orig' that do not touch any of merge.

    :param orig: Array of interval-like objects (e.g. peaks)
    :param merge: Array of interval-like objects (e.g. peaks)
    :return: ``orig`` itself when ``merge`` is empty; otherwise a new array of
        ``orig``'s dtype in which every run of ``orig`` entries touching a
        ``merge`` entry is replaced by that ``merge`` entry.

    """
    if len(merge) == 0:
        return orig

    # One [start, end) window of orig-indices per entry of merge.
    windows = strax.touching_windows(orig, merge)
    n_dropped = (windows[:, 1] - windows[:, 0]).sum()

    n_out = len(orig) - n_dropped + len(merge)
    combined = np.zeros(n_out, dtype=orig.dtype)
    _replace_merged(combined, orig, merge, windows)
    return combined
@numba.njit(cache=True, nogil=True)
def _replace_merged(result, orig, merge, skip_windows):
    """Fill ``result`` with ``orig``, substituting each skip window by its merged entry.

    Walks through ``orig`` once. Entries whose index falls inside the current
    skip window are dropped; when the walk reaches the (exclusive) end of a
    window, the corresponding ``merge`` entry is written to ``result`` first.

    :param result: Pre-allocated output array, modified in place.
    :param orig: Array of the original entries.
    :param merge: Array of merged entries, one per row of ``skip_windows``.
    :param skip_windows: Array of (start, end) index pairs into ``orig``, end
        exclusive; must contain at least one row.

    """
    result_i = window_i = 0
    skip_start, skip_end = skip_windows[0]
    n_orig = len(orig)

    for orig_i in range(n_orig):
        if orig_i == skip_end:
            # Just walked past a window: emit its merged entry, then move on
            # to the next window.
            result[result_i] = merge[window_i]
            result_i += 1
            window_i += 1
            if window_i == len(skip_windows):
                # No windows left; park the bounds beyond the end of orig so
                # neither test in this loop can trigger again.
                skip_start = skip_end = n_orig + 100
            else:
                skip_start, skip_end = skip_windows[window_i]

        if orig_i >= skip_start:
            # Inside the current window: this entry is replaced, drop it.
            continue

        result[result_i] = orig[orig_i]
        result_i += 1

    if skip_end == n_orig:
        # Still have to insert the last merged S2
        # since orig_i == skip_end is never met
        assert result_i == len(result) - 1
        assert window_i == len(merge) - 1
        result[result_i] = merge[window_i]
        result_i += 1
        window_i += 1

    # Every output slot must be filled and every window consumed.
    assert result_i == len(result)
    assert window_i == len(skip_windows)
@export
def add_lone_hits(
    peaks, lone_hits, to_pe, n_top_channels=0, store_data_top=False, store_data_start=False
):
    """Add the contribution of lone hits to the peaks that (fully) contain them.

    Intended for lone hits that end up inside a peak, e.g. after merging.
    Peak area and waveform data are modified in place; nothing is returned.

    :param peaks: Numpy array of peaks
    :param lone_hits: Numpy array of lone_hits
    :param to_pe: Gain values to convert lone hit area into PE.
    :param n_top_channels: Number of top array channels.
    :param store_data_top: Boolean which indicates whether to store the top
        array waveform in the peak.
    :param store_data_start: Boolean which indicates whether to store the first
        samples of the waveform in the peak.

    """
    # Validate the inputs in plain Python before entering the compiled core.
    _fully_contained_in_sanity(lone_hits, peaks)

    waveform_options = dict(
        n_top_channels=n_top_channels,
        store_data_top=store_data_top,
        store_data_start=store_data_start,
    )
    _add_lone_hits(peaks, lone_hits, to_pe, **waveform_options)
@numba.njit(cache=True, nogil=True)
def _add_lone_hits(
    peaks, lone_hits, to_pe, n_top_channels=0, store_data_top=False, store_data_start=False
):
    """The core function of add_lone_hits.

    For every lone hit that ``_fully_contained_in`` assigns to a peak, add the
    hit area (converted with ``to_pe``) to the peak's area, per-channel area
    and, as a single-sample delta pulse, to its waveform(s). ``peaks`` is
    modified in place.

    :param peaks: Numpy array of peaks
    :param lone_hits: Numpy array of lone_hits
    :param to_pe: Gain values to convert lone hit area into PE, indexed by
        channel.
    :param n_top_channels: Number of top array channels; hits in channels below
        this number are also added to ``data_top``.
    :param store_data_top: Whether to update the ``data_top`` waveform.
    :param store_data_start: Whether to update the ``data_start`` waveform.
    :raises ValueError: If the sample index of a hit falls outside the peak's
        waveform buffer.

    """
    fully_contained_index = _fully_contained_in(lone_hits, peaks)

    for fc_i, lh_i in zip(fully_contained_index, lone_hits):
        if fc_i == -1:
            # Hit is not contained in any peak.
            continue
        p = peaks[fc_i]
        lh_area = lh_i["area"] * to_pe[lh_i["channel"]]
        p["area"] += lh_area
        p["area_per_channel"][lh_i["channel"]] += lh_area

        # Add lone hit as delta pulse to waveform:
        index = (lh_i["time"] - p["time"]) // p["dt"]
        # Valid sample indices are 0 .. len(data) - 1. The upper bound has to
        # be rejected too: compiled code does not bounds-check array writes by
        # default, so index == len(data) would write past the buffer.
        if index < 0 or index >= len(p["data"]):
            raise ValueError("Hit outside of full containment!")
        p["data"][index] += lh_area

        if store_data_top:
            if lh_i["channel"] < n_top_channels:
                p["data_top"][index] += lh_area

        if store_data_start:
            # Non-downsampled waveforms have a fixed minimum dt
            # NOTE(review): the literal 10 is presumably that minimum dt in ns
            # -- confirm against the digitizer sampling used upstream.
            index_wf_start = (lh_i["time"] - p["time"]) // 10
            if index_wf_start < 0:
                raise ValueError("Hit outside of full containment!")
            # data_start only covers the first samples of the peak; hits
            # beyond it are simply not recorded there.
            if index_wf_start < len(p["data_start"]):
                p["data_start"][index_wf_start] += lh_area