Source code for strax.mailbox

# pylint: disable=redefined-builtin
from concurrent.futures import Future, TimeoutError
import heapq
import sys
import threading
import typing
import logging

from strax.utils import exporter

export, __all__ = exporter()


@export
class MailboxException(Exception):
    """Base class for all exceptions raised by the mailbox machinery."""
@export
class MailboxReadTimeout(MailboxException):
    """Raised when a wait for a message times out.

    This covers a reader that did not get its next message in time, and a lazy
    sender for which no driving subscriber requested a new message in time.
    """
@export
class MailboxFullTimeout(MailboxException):
    """Raised by a sender when a full mailbox did not free up room in time."""
@export
class InvalidMessageNumber(MailboxException):
    """Raised when a message number is not an integer, or when it refers to a
    message that the subscribers have already read past."""
@export
class MailBoxAlreadyClosed(MailboxException):
    """Raised when a message is sent to a mailbox that was already closed."""
@export
class MailboxKilled(MailboxException):
    """Raised in readers of a killed mailbox, and in senders of a force-killed
    mailbox. The kill reason is passed as the first exception argument."""
@export
class Mailbox:
    """Publish/subscribe mailbox for building complex pipelines out of simple
    iterators, using multithreading.

    A sender can be any iterable. To read from the mailbox, either:

     1. Use .subscribe() to get an iterator. You can only use one of these
        per thread.
     2. Use .add_reader(f) to subscribe the function f. f should take an
        iterator as its first argument (and actually iterate over it, of
        course).

    Each sender and receiver is wrapped in a thread, so they can be paused:

     - senders, if the mailbox is full;
     - readers, if they call next() but the next message is not yet
       available.

    Any futures sent in are awaited before they are passed to receivers.

    Exceptions in a sender cause MailboxKilled to be raised in each reader.
    If the reader doesn't catch this, and it writes to another mailbox, this
    therefore kills that mailbox (raises MailboxKilled for each reader) as
    well. Thus MailboxKilled exceptions travel downstream in pipelines.

    Sender threads are not killed by exceptions raised in readers. To kill
    sender threads too, use .kill(upstream=True). Even this does not propagate
    further upstream than the immediate sender threads.

    """

    # In strax, these are overridden by context options
    # 'timeout' and 'max_messages'. They are here only to support
    # creating mailboxes directly without strax.
    DEFAULT_TIMEOUT = 300
    DEFAULT_MAX_MESSAGES = 4

    class _LoggedCondition:
        """Small helper class which wraps "threading.Condition" to get some
        useful logging information for debugging."""

        def __init__(self, name, log, lock):
            self.log = log
            self._lock = lock
            self.name = name
            self.log.debug(f'Initialize "{name}" with lock state: {lock}.')
            self.threading_condition = threading.Condition(lock=lock)

        def notify_all(self):
            self.log.debug(f"Notifying all for {self.name} with lock state: {self._lock}")
            self.threading_condition.notify_all()

        def wait_for(self, *args, **kwargs):
            self.log.debug(
                f'Waiting for a change in "{self.name}" with state {args[0]} and lock '
                f"state: {self._lock}"
            )
            return self.threading_condition.wait_for(*args, **kwargs)

    def __init__(self, name="mailbox", timeout=None, lazy=False, max_messages=None):
        """Create an empty mailbox without senders or readers.

        :param name: Name of the mailbox; also used as the logger name.
        :param timeout: Timeout passed to every condition wait, future result
            and thread join. Defaults to DEFAULT_TIMEOUT.
        :param lazy: If True, new messages are only fetched from the source
            when a driving subscriber is waiting for one, and the buffer size
            limit is lifted.
        :param max_messages: Maximum number of buffered messages before
            senders have to wait. Defaults to DEFAULT_MAX_MESSAGES; ignored
            (set to infinity) in lazy mode.

        """
        self.name = name
        self.timeout = self.DEFAULT_TIMEOUT if timeout is None else timeout
        self.max_messages = self.DEFAULT_MAX_MESSAGES if max_messages is None else max_messages
        self.lazy = lazy
        if self.lazy:
            self.max_messages = float("inf")

        self.closed = False
        self.force_killed = False
        self.killed = False
        self.killed_because = None

        # Heap of (msg_number, msg) tuples
        self._mailbox = []
        # Per-subscriber state, indexed by subscriber number
        self._subscribers_have_read = []
        self._subscriber_waiting_for = []
        self._subscriber_can_drive = []
        self._n_sent = 0
        self._threads = []
        self._lock = threading.RLock()
        self.log = logging.getLogger(self.name)

        # Conditions to wait on
        # Do NOT call notify_all when the condition is False!
        # We use wait_for, which also returns False when the timeout is broken
        # (Is this an odd design decision in the standard library
        # or am I misunderstanding something?)

        # If you're waiting to read a new message that hasn't yet arrived:
        self._read_condition = self._LoggedCondition(
            "_read_condition", self.log, lock=self._lock
        )

        # If you're waiting to write a new message because the mailbox is full
        self._write_condition = self._LoggedCondition(
            "_write_condition", self.log, lock=self._lock
        )

        # If you're waiting to fetch a new element because the subscribers
        # still have other things to do
        self._fetch_new_condition = self._LoggedCondition(
            "_fetch_new_condition", self.log, lock=self._lock
        )

        self.log.debug("Initialized")

    def add_sender(self, source, name=None):
        """Configure mailbox to read from an iterable source.

        The thread is only created here; it is started by .start().

        :param source: Iterable to read from
        :param name: Name of the thread in which the function will run.
            Defaults to source:<mailbox_name>

        """
        if name is None:
            name = f"source:{self.name}"
        t = threading.Thread(target=self._send_from, name=name, args=(source,))
        self._threads.append(t)

    def add_reader(self, subscriber, name=None, can_drive=True, **kwargs):
        """Subscribe a function to the mailbox.

        The thread is only created here; it is started by .start().

        :param subscriber: Function which accepts a generator over messages as
            the first argument. Any kwargs will also be passed to the function.
        :param name: Name of the thread in which the function will run.
            Defaults to read_<number>:<mailbox_name>
        :param can_drive: Whether this reader can cause new messages to be
            generated when in lazy mode.

        """
        if name is None:
            name = f"read_{self._n_subscribers}:{self.name}"
        t = threading.Thread(
            target=subscriber,
            name=name,
            args=(self.subscribe(can_drive=can_drive),),
            kwargs=kwargs,
        )
        self._threads.append(t)

    def subscribe(self, can_drive=True):
        """Return generator over messages in the mailbox.

        :param can_drive: Whether this subscriber can cause new messages to be
            generated when in lazy mode.

        """
        with self._lock:
            subscriber_i = self._n_subscribers
            self._subscriber_can_drive.append(can_drive)
            self._subscribers_have_read.append(-1)
            self._subscriber_waiting_for.append(None)
            self.log.debug(f"Added subscriber {subscriber_i}")
            return self._read(subscriber_i=subscriber_i)

    def start(self):
        """Start all sender and reader threads registered so far.

        :raises ValueError: if the mailbox has no subscribers.

        """
        if not self._n_subscribers:
            raise ValueError(f"Attempt to start mailbox {self.name} without subscribers")
        for t in self._threads:
            t.start()

    def kill_from_exception(self, e, reraise=True):
        """Kill the mailbox following a caught exception e.

        :param e: The exception that was caught.
        :param reraise: If True, re-raise e after killing the mailbox.
            MailboxKilled exceptions are never re-raised.

        """
        if isinstance(e, MailboxKilled):
            # Kill this mailbox too.
            self.log.debug("Propagating MailboxKilled exception")
            # Tolerate a MailboxKilled that was raised without a reason
            self.kill(reason=e.args[0] if e.args else None)
            # Do NOT raise! One traceback on the screen is enough.
        else:
            self.log.debug(f"Killing mailbox due to exception {e}!")
            self.kill(reason=(e.__class__, e, sys.exc_info()[2]))
            if reraise:
                raise e

    def kill(self, upstream=True, reason=None):
        """Mark the mailbox as killed and wake up all waiting threads.

        :param upstream: If True, also force-kill the mailbox, so that
            senders get a MailboxKilled exception instead of having their
            messages silently dropped.
        :param reason: Stored as killed_because and passed to the
            MailboxKilled exceptions raised afterwards. Only the reason of the
            first kill is kept.

        """
        with self._lock:
            self.log.debug(f"Kill received by {self.name}")
            if upstream:
                self.force_killed = True
            if self.killed:
                self.log.debug(f"Double kill on {self.name} = NOP")
                return
            self.killed = True
            self.killed_because = reason
            self._read_condition.notify_all()
            self._write_condition.notify_all()
            self._fetch_new_condition.notify_all()

    def cleanup(self):
        """Join all threads of this mailbox.

        :raises RuntimeError: if a thread is still alive after waiting for
            the mailbox timeout.

        """
        for t in self._threads:
            t.join(timeout=self.timeout)
            if t.is_alive():
                raise RuntimeError(f"Thread {t.name} did not terminate!")

    def _can_fetch(self):
        """Return whether we can fetch and then send the next element from the
        source.

        Only meaningful in lazy mode. Returns True if the mailbox is killed
        (so waiting senders wake up and let send() deal with it), False if a
        subscriber still waits for a message that is already in the mailbox,
        and otherwise True only if a driving subscriber is waiting.

        """
        assert self.lazy

        # The .send() knows how to handle the exception properly
        # (if we raise here we will likely duplicate the exception)
        if self.killed:
            return True

        # If someone is still waiting for a message we already have
        # (so they just haven't woken up yet), don't fetch a new message.
        if len(self._mailbox) and any(
            x is not None and x <= self._lowest_msg_number for x in self._subscriber_waiting_for
        ):
            return False

        # Everyone is waiting for the new chunk or not at all.
        # Fetch only if a driver is waiting.
        return any(
            can_drive and waiting_for is not None
            for can_drive, waiting_for in zip(
                self._subscriber_can_drive, self._subscriber_waiting_for
            )
        )

    def _send_from(self, iterable):
        """Send to mailbox from iterable, exiting appropriately if an exception
        is thrown.

        Any exception kills the mailbox (and is re-raised unless it is a
        MailboxKilled); regular exhaustion of the iterable closes it.

        """
        try:
            # Accept any iterable, but leave iterators (e.g. generators)
            # untouched so they can still be informed through .throw().
            source = iterable if hasattr(iterable, "__next__") else iter(iterable)
            i = 0
            while True:
                if self.lazy:
                    with self._lock:
                        if not self._can_fetch():
                            self.log.debug(
                                f"Waiting to fetch {i}, "
                                f"{self._subscriber_waiting_for}, "
                                f"{self._subscriber_can_drive}"
                            )
                            if not self._fetch_new_condition.wait_for(
                                self._can_fetch, timeout=self.timeout
                            ):
                                raise MailboxReadTimeout(
                                    f"{self} could not progress beyond {i}, "
                                    "no driving subscriber requested it."
                                )

                try:
                    x = next(source)
                except StopIteration:
                    # No need to send this yet, close will do that
                    break

                try:
                    self.send(x)
                except Exception as e:
                    # Inform the source we're going down. Only generators
                    # support this; plain iterators have nothing to clean up.
                    throw = getattr(source, "throw", None)
                    if throw is not None:
                        throw(e)
                    raise

                i += 1

        except Exception as e:
            self.kill_from_exception(e)
        else:
            self.log.debug("Producing iterable exhausted, regular stop")
            self.close()

    def send(self, msg, msg_number: typing.Union[int, None] = None):
        """Send a message.

        If the message is a future, receivers will be passed its result.
        (possibly waiting for completion if needed)

        If the mailbox is currently full, sleep until there is room for your
        message (or timeout occurs)

        :param msg: The message to send.
        :param msg_number: Number of the message; defaults to the number of
            messages sent so far. Must be equal to its int(...) value.
        :raises MailBoxAlreadyClosed: if the mailbox was closed.
        :raises MailboxKilled: if the mailbox was force-killed. Messages to a
            mailbox that was killed without force are silently dropped.
        :raises InvalidMessageNumber: if msg_number is not an integer, or
            subscribers already read past it.
        :raises MailboxFullTimeout: if no room became available in time.

        """
        with self._lock:
            if self.closed:
                raise MailBoxAlreadyClosed(f"Can't send to closed {self.name}")
            if self.force_killed:
                self.log.debug(f"Sender found {self.name} force-killed")
                raise MailboxKilled(self.killed_because)
            if self.killed:
                self.log.debug("Send to killed mailbox: message lost")
                return

            # We accept int numbers or anything which equals to it's int(...)
            # (like numpy integers)
            if msg_number is None:
                msg_number = self._n_sent
            # Explicit check rather than assert, so it also works under -O
            try:
                is_integer = bool(msg_number == int(msg_number))
            except (TypeError, ValueError, OverflowError):
                is_integer = False
            if not is_integer:
                raise InvalidMessageNumber("Msg numbers must be integers")

            read_until = min(self._subscribers_have_read, default=-1)
            if msg_number <= read_until:
                raise InvalidMessageNumber(
                    f"Attempt to send message {msg_number} while "
                    f"subscribers already read {read_until}."
                )

            def can_write():
                return len(self._mailbox) < self.max_messages or self.killed

            if not can_write():
                self.log.debug("Subscribers have read: " + str(self._subscribers_have_read))
                self.log.debug(f"Mailbox full, wait to send {msg_number}")
            if not self._write_condition.wait_for(can_write, timeout=self.timeout):
                raise MailboxFullTimeout(f"Mailbox buffer for {self.name} emptied too slow.")

            if self.killed:
                self.log.debug(
                    f"Sender found {self.name} killed while waiting for room for new messages."
                )
                if self.force_killed:
                    raise MailboxKilled(self.killed_because)
                return

            heapq.heappush(self._mailbox, (msg_number, msg))
            self.log.debug(f"Sent {msg_number}")
            self._n_sent += 1
            self._read_condition.notify_all()

    def close(self):
        """Send StopIteration as the final message and refuse further sends."""
        self.log.debug("Closing; sending StopIteration")
        with self._lock:
            self.send(StopIteration)
            self.closed = True
        self.log.debug("Closed to incoming messages")

    def _read(self, subscriber_i):
        """Iterate over incoming messages in order.

        Your thread will sleep until the next message is available, or timeout
        expires (in which case MailboxReadTimeout is raised)

        :param subscriber_i: Number of the subscriber this generator serves.
        :raises MailboxReadTimeout: if the next message did not arrive in time.
        :raises MailboxKilled: if the mailbox was killed.
        :raises TimeoutError: if a future message did not complete in time.

        """
        self.log.debug("Start reading")
        next_number = 0
        last_message = False

        while not last_message:
            with self._lock:
                # Wait for new messages
                def next_ready():
                    return self._has_msg(next_number) or self.killed

                if not next_ready():
                    self.log.debug(f"Checking/waiting for {next_number}")
                    self._subscriber_waiting_for[subscriber_i] = next_number
                    if self.lazy and self._can_fetch():
                        self._fetch_new_condition.notify_all()
                    if not self._read_condition.wait_for(next_ready, self.timeout):
                        raise MailboxReadTimeout(f"{self.name} did not get {next_number} in time.")
                    self._subscriber_waiting_for[subscriber_i] = None

                if self.killed:
                    self.log.debug(f"Reader finds {self.name} killed")
                    raise MailboxKilled(self.killed_because)

                # Grab all messages we can yield
                to_yield = []
                while self._has_msg(next_number):
                    msg = self._get_msg(next_number)
                    if msg is StopIteration:
                        self.log.debug(f"{next_number} is StopIteration")
                        last_message = True
                    to_yield.append((next_number, msg))
                    next_number += 1

                if len(to_yield) > 1:
                    self.log.debug(
                        f"Read {to_yield[0][0]}-{to_yield[-1][0]} in subscriber {subscriber_i}"
                    )
                else:
                    self.log.debug(f"Read {to_yield[0][0]} in subscriber {subscriber_i}")

                self._subscribers_have_read[subscriber_i] = next_number - 1

                # Clean up the mailbox: drop messages every subscriber has read
                while len(self._mailbox) and (
                    min(self._subscribers_have_read) >= self._lowest_msg_number
                ):
                    heapq.heappop(self._mailbox)

                if self.lazy and self._can_fetch():
                    self._fetch_new_condition.notify_all()
                self._write_condition.notify_all()

            # Yield outside the lock, so other threads can proceed meanwhile
            for msg_number, msg in to_yield:
                if msg is StopIteration:
                    break
                elif isinstance(msg, Future):
                    if not msg.done():
                        self.log.debug(f"Waiting for future {msg_number}")
                        try:
                            res = msg.result(timeout=self.timeout)
                        except TimeoutError:
                            raise TimeoutError(f"Future {msg_number} timed out!")
                        self.log.debug(f"Future {msg_number} completed")
                    else:
                        res = msg.result()
                        self.log.debug(f"Future {msg_number} was already done")
                else:
                    res = msg

                try:
                    yield res
                except Exception as e:
                    # TODO: Should I also handle timeout errors like this?
                    self.kill_from_exception(e)

        self.log.debug("Done reading")

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    def _get_msg(self, number):
        """Return the buffered message with the given number.

        :raises RuntimeError: if no such message is in the mailbox.

        """
        for msg_number, msg in self._mailbox:
            if msg_number == number:
                return msg
        raise RuntimeError(f"Could not find message {number}")

    def _has_msg(self, number):
        """Return if mailbox has message number.

        Also returns True if mailbox is killed, so be sure to check self.killed
        after this!

        """
        if self.killed:
            return True
        return any(msg_number == number for msg_number, _ in self._mailbox)

    @property
    def _n_subscribers(self):
        """Number of subscribers registered so far."""
        return len(self._subscribers_have_read)

    @property
    def _lowest_msg_number(self):
        """Lowest message number in the mailbox; the mailbox must not be empty."""
        return self._mailbox[0][0]
@export
def divide_outputs(
    source, mailboxes: typing.Dict[str, Mailbox], lazy=False, flow_freely=tuple(), outputs=None
):
    """Mail sorter: take dicts of arrays from source and send each array to
    the mailbox registered under the same key.

    :param source: Generator yielding dicts that map mailbox keys to messages.
    :param mailboxes: Dict mapping keys to the mailboxes to send to.
    :param lazy: If True, wait before each fetch until every mailbox in
        outputs (except those in flow_freely) allows fetching a new message.
    :param flow_freely: Keys of mailboxes we never wait for.
    :param outputs: Keys of the mailboxes to wait for, and to kill or close
        when done. Defaults to all keys of mailboxes.

    Exceptions kill all output mailboxes and are re-raised, except for
    MailboxKilled. Regular exhaustion of source closes all output mailboxes.

    """
    # TODO: check that exceptions raised at this point are handled properly
    if outputs is None:
        outputs = mailboxes.keys()
    mbs_to_kill = [mailboxes[key] for key in outputs]

    def wait_until_requested(mailbox, n_done):
        """Block until mailbox allows fetching, or raise on timeout."""
        with mailbox._lock:
            if mailbox._can_fetch():
                return
            mailbox.log.debug(
                f"Waiting to fetch {n_done}, "
                f"{mailbox._subscriber_waiting_for}, "
                f"{mailbox._subscriber_can_drive}"
            )
            requested = mailbox._fetch_new_condition.wait_for(
                mailbox._can_fetch, timeout=mailbox.timeout
            )
            if not requested:
                raise MailboxReadTimeout(
                    f"{mailbox} could not progress beyond {n_done}, "
                    "no driving subscriber requested it."
                )

    # TODO: this code duplicates exception handling and cleanup
    # from Mailbox.send_from! Can we avoid that somehow?
    n_done = 0
    try:
        while True:
            for key in outputs:
                mailbox = mailboxes[key]
                if key in flow_freely:
                    # Do not block on account of these guys
                    mailbox.log.debug(f"Not locking {key}")
                elif lazy:
                    wait_until_requested(mailbox, n_done)

            try:
                result = next(source)
            except StopIteration:
                # No need to send this yet, close will do that
                break

            try:
                for key, payload in result.items():
                    mailboxes[key].send(payload)
            except Exception as e:
                # Inform the source we're going down
                source.throw(e)
                raise

            n_done += 1

    except Exception as e:
        for mailbox in mbs_to_kill:
            mailbox.kill_from_exception(e, reraise=False)
        if not isinstance(e, MailboxKilled):
            raise

    else:
        for mailbox in mbs_to_kill:
            mailbox.close()