from __future__ import annotations
import copy
from collections import deque
from typing import Union
import ioiocore as ioc
import numpy as np
from ...common.constants import Constants
from ..core.i_port import IPort
from ..core.io_node import IONode
from ..core.o_port import OPort
[docs]
class Router(IONode):
"""Channel routing and selection node for flexible data flow management.
Provides channel selection and routing capabilities for BCI data pipelines.
Allows selecting specific channels from multiple input ports and routing
them to multiple output ports. Supports both simple channel selection
(index lists) and complex multi-port routing (dictionary mapping).
"""
#: Special constant for selecting all available channels
ALL: list = [-1]
# Type annotation for the internal routing map
_map: dict
# Type annotation for output channel counts
_channel_count_out: dict
# Type annotation for SYNC/ASYNC port tracking
_sync_ports: set
_async_ports: set
# Type annotation for ASYNC data buffers (queues)
_async_buffers: dict[str, deque]
# Type annotation for last known ASYNC values (used when queue is empty)
_async_last_values: dict[str, np.ndarray]
[docs]
class Configuration(ioc.IONode.Configuration):
"""Configuration class for Router parameters."""
[docs]
class Keys(ioc.IONode.Configuration.Keys):
"""Configuration key constants for the Router."""
#: Input channels configuration key
INPUT_CHANNELS = "input_channels"
#: Output channels configuration key
OUTPUT_CHANNELS = "output_channels"
[docs]
def __init__(
self,
input_channels: Union[list, dict] = None,
output_channels: Union[list, dict] = None,
**kwargs,
):
"""Initialize the Router node with channel selection configurations.
Args:
input_channels: Specification for input channel selection. Can be
None (all channels), list (channel indices), or dict (port
name to channel indices mapping).
output_channels: Specification for output channel selection.
Same format as input_channels.
**kwargs: Additional configuration parameters passed to IONode.
Raises:
ValueError: If input_channels or output_channels is empty.
"""
# Set default input channels to all channels on default port
if input_channels is None:
input_channels = {Constants.Defaults.PORT_IN: Router.ALL}
# Convert list format to dictionary format for input channels
if type(input_channels) is list:
if len(input_channels) == 0:
raise ValueError("input_channels must not be empty.")
# Convert single list to list of lists if needed
if type(input_channels[0]) is not list:
input_channels = [input_channels]
# Create port mappings
if len(input_channels) == 1:
input_channels = {
Constants.Defaults.PORT_IN: input_channels[0]
}
else:
input_channels = {
f"in{i + 1}": val for i, val in enumerate(input_channels)
}
# Create input port configurations
input_ports = [
IPort.Configuration(name=name, timing=Constants.Timing.INHERITED)
for name in input_channels.keys()
]
input_ports = kwargs.pop(
Router.Configuration.Keys.INPUT_PORTS, input_ports)
# Set default output channels to all channels on default port
if output_channels is None:
output_channels = {Constants.Defaults.PORT_OUT: Router.ALL}
# Convert list format to dictionary format for output channels
if type(output_channels) is list:
if len(output_channels) == 0:
raise ValueError("output_channels must not be empty.")
# Convert single list to list of lists if needed
if type(output_channels[0]) is not list:
output_channels = [output_channels]
# Create port mappings
if len(output_channels) == 1:
output_channels = {
Constants.Defaults.PORT_OUT: output_channels[0]
}
else:
output_channels = {
f"out{i + 1}": val for i, val in enumerate(output_channels)
}
# Create output port configurations
output_ports = [
OPort.Configuration(name=name) for name in output_channels.keys()
]
output_ports = kwargs.pop(
Router.Configuration.Keys.OUTPUT_PORTS, output_ports)
# Initialize internal routing map
self._map = {}
# Store output channel counts for use in step()
self._channel_count_out: dict = {}
# Initialize SYNC/ASYNC port tracking
self._sync_ports: set = set()
self._async_ports: set = set()
# Initialize ASYNC data buffers (queues)
self._async_buffers: dict[str, deque] = {}
# Initialize last known ASYNC values
self._async_last_values: dict[str, np.ndarray] = {}
# Initialize parent IONode with all configurations
IONode.__init__(
self,
input_channels=input_channels,
output_channels=output_channels,
input_ports=input_ports,
output_ports=output_ports,
**kwargs,
)
[docs]
def setup(
self, data: dict[str, np.ndarray], port_context_in: dict[str, dict]
) -> dict[str, dict]:
"""Set up the Router node and build the internal channel mapping.
Creates the internal routing map that defines which input channels
are connected to which output channels. Validates that all input
ports have compatible sampling rates, frame sizes, and data types.
Args:
data: Initial data dictionary for port configuration.
port_context_in: Input port context information with channel
counts, sampling rates, frame sizes, and data types.
Returns:
Output port context with routing information and updated
channel counts for each output port.
Raises:
ValueError: If input ports have incompatible parameters.
"""
# Get commonly used configuration keys
cc_key = Constants.Keys.CHANNEL_COUNT
name_key = IPort.Configuration.Keys.NAME
# Build input channel mapping
input_map: list = []
ip_key = Router.Configuration.Keys.INPUT_PORTS
ic_key = Router.Configuration.Keys.INPUT_CHANNELS
# Process each input port to build channel mapping
for k in range(len(self.config[ip_key])):
port = self.config[ip_key][k]
name = port[name_key]
sel = self.config[ic_key][name]
# Expand "ALL" to actual channel range
if sel == Router.ALL:
sel = range(port_context_in[name][cc_key])
# Add each selected channel to the input map
input_map.extend([{name: [n]} for n in sel])
# Build output port mapping using input map
op_key = Router.Configuration.Keys.OUTPUT_PORTS
oc_key = Router.Configuration.Keys.OUTPUT_CHANNELS
for k in range(len(self.config[op_key])):
port = self.config[op_key][k]
name = port[name_key]
sel = self.config[oc_key][name]
# Expand "ALL" to full input map range
if sel == Router.ALL:
sel = range(len(input_map))
# Map selected input channels to this output port
map = [input_map[n] for n in sel]
map_grouped = []
from itertools import groupby
for key, group in groupby(map, key=lambda x: list(x.keys())[0]):
values = []
for item in group:
values.extend(item[key])
map_grouped.append({key: values})
self._map[name] = map_grouped
# Identify SYNC vs ASYNC ports and initialize buffers
timing_key = IPort.Configuration.Keys.TIMING
self._sync_ports = set()
self._async_ports = set()
self._async_buffers = {}
for port_name, context in port_context_in.items():
port_timing = context.get(timing_key, Constants.Timing.SYNC)
if port_timing == Constants.Timing.ASYNC:
self._async_ports.add(port_name)
# Initialize buffer queue and last value with zeros
cc = context.get(Constants.Keys.CHANNEL_COUNT, 1)
self._async_buffers[port_name] = deque()
self._async_last_values[port_name] = np.zeros(
(1, cc), dtype=Constants.DATA_TYPE
)
else:
self._sync_ports.add(port_name)
# Validate sampling rate consistency across SYNC input ports only
sr_key = Constants.Keys.SAMPLING_RATE
sampling_rates = [
md.get(sr_key, None)
for name, md in port_context_in.items()
if name in self._sync_ports
]
sampling_rates = [sr for sr in sampling_rates if sr is not None]
if len(set(sampling_rates)) > 1:
err_msg = "All SYNC ports must have the same sampling rate."
raise ValueError(err_msg)
sr = sampling_rates[0] if sampling_rates else None
# Validate frame size consistency across SYNC input ports only
fsz_key = Constants.Keys.FRAME_SIZE
frame_sizes = [
md.get(fsz_key, None)
for name, md in port_context_in.items()
if name in self._sync_ports
]
frame_sizes = [fsz for fsz in frame_sizes if fsz is not None]
if len(set(frame_sizes)) > 1:
raise ValueError("All SYNC ports must have the same frame size.")
fsz = frame_sizes[0] if frame_sizes else 1
# Validate data type consistency across all input ports
type_key = IPort.Configuration.Keys.TYPE
types = [md.get(type_key, None) for md in port_context_in.values()]
types = [tp for tp in types if (tp != "Any" and tp is not None)]
if len(set(types)) > 1:
raise ValueError("All ports must have the same type.")
tp = types[0] if len(types) > 0 else "Any"
# Build output port context information
port_context_out: dict[str, dict] = {}
cc_key = Constants.Keys.CHANNEL_COUNT
op_key = self.Configuration.Keys.OUTPUT_PORTS
name_key = OPort.Configuration.Keys.NAME
timing_key = OPort.Configuration.Keys.TIMING
for op in self.config[op_key]:
context = {}
# Get all input ports referenced by this output
in_ports = {key for d in self._map[op[name_key]] for key in d}
# Copy context from input ports with unique naming
for key1 in in_ports:
full_key = self.name + "_" + key1
context[full_key] = {}
for key2 in port_context_in[key1]:
# Skip ID and NAME keys from input context
if key2 in [
IPort.Configuration.Keys.ID,
IPort.Configuration.Keys.NAME,
]:
continue
# Deep copy to prevent reference leakage
context[full_key][key2] = port_context_in[key1][key2]
# Set output port context with validated values
context[cc_key] = sum(
len(d[key]) for d in self._map[op[name_key]] for key in d
)
context[sr_key] = sr
context[fsz_key] = fsz
context[type_key] = tp
context[timing_key] = op[timing_key]
port_context_out[op[name_key]] = context
self._channel_count_out[op[name_key]] = context[cc_key]
return port_context_out
[docs]
def step(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
"""Process one frame of data by routing channels according to mapping.
Routes channels from input ports to output ports based on the channel
mapping established during setup. Each output port receives data from
its configured subset of input channels.
For mixed SYNC/ASYNC inputs:
- ASYNC data is buffered when it arrives
- Output is only produced when SYNC data is available
- Buffered ASYNC data is used when SYNC data triggers output
Args:
data: Dictionary containing input data arrays for each port.
Keys are port names, values are arrays with shape
(frame_size, channel_count).
Returns:
Dictionary containing output data arrays for each output port.
Keys are output port names, values are arrays with shape
(frame_size, selected_channel_count). Returns empty dict if
no SYNC data is available.
"""
# Enqueue any new ASYNC data
for port_name in self._async_ports:
if port_name in data and data[port_name] is not None:
self._async_buffers[port_name].append(data[port_name])
# Check if any SYNC data is available
sync_data_available = all(
port_name in data and data[port_name] is not None
for port_name in self._sync_ports
)
# Only produce output when SYNC data is available
# (or if there are no SYNC ports at all)
if not sync_data_available and len(self._sync_ports) > 0:
return {}
# Build merged data dict: SYNC data + dequeued ASYNC data
merged_data = {}
for port_name in data:
if port_name in self._sync_ports:
merged_data[port_name] = data[port_name]
# Dequeue next ASYNC sample for each ASYNC port (or use last value)
for port_name in self._async_ports:
if self._async_buffers[port_name]:
# Consume next sample from queue and update last value
self._async_last_values[port_name] = (
self._async_buffers[port_name].popleft()
)
merged_data[port_name] = self._async_last_values[port_name]
data_out: dict = {}
# Process each output port mapping
for port_out, mapping in self._map.items():
# Calculate total output channels
channel_count = self._channel_count_out[port_out]
# Get frame size from first available SYNC input, or default to 1
frame_size = 1
for port_name in self._sync_ports:
if (
port_name in merged_data
and merged_data[port_name] is not None
):
frame_size = merged_data[port_name].shape[0]
break
# Pre-allocate output array
output_array = np.zeros(
(frame_size, channel_count), dtype=Constants.DATA_TYPE
)
# Fill output array using direct slicing
col_idx = 0
for m in mapping:
for port_in, ch_in in m.items():
try:
num_ch = len(ch_in)
port_data = merged_data.get(port_in)
if port_data is not None:
# Broadcast ASYNC data (1 row) to match frame_size
if port_data.shape[0] == 1 and frame_size > 1:
output_array[:, col_idx:col_idx + num_ch] = (
np.broadcast_to(
port_data[:, ch_in],
(frame_size, num_ch)
)
)
else:
output_array[:, col_idx:col_idx + num_ch] = (
port_data[:, ch_in]
)
col_idx += num_ch
except Exception:
# Handle missing data with zeros
col_idx += len(ch_in)
data_out[port_out] = output_array
return data_out