from __future__ import annotations

import copy
from abc import abstractmethod

import ioiocore as ioc
import numpy as np

from ...common.constants import Constants
from ..core.i_port import IPort
from .node import Node
class IONode(ioc.IONode, Node):
    """Abstract base class for input/output nodes in the g.Pype pipeline.

    Combines ``ioiocore.IONode`` and ``Node`` functionality for signal
    processing nodes that have both input and output ports. Provides the
    validation and context-merging logic used during ``setup()``.
    Subclasses must implement the abstract ``step()`` method.
    """

    def __init__(
        self,
        input_ports: list[ioc.IPort.Configuration] | None = None,
        output_ports: list[ioc.OPort.Configuration] | None = None,
        **kwargs,
    ):
        """Initialize the IONode with input and output port configurations.

        Args:
            input_ports: List of input port configurations or None.
            output_ports: List of output port configurations or None.
            **kwargs: Additional arguments forwarded to
                ``ioc.IONode.__init__``.
        """
        ioc.IONode.__init__(
            self, input_ports=input_ports, output_ports=output_ports, **kwargs
        )
        # Node is initialized with this very instance as its target.
        Node.__init__(self, target=self)

    def setup(
        self, data: dict[str, np.ndarray], port_context_in: dict[str, dict]
    ) -> dict[str, dict]:
        """Set up the node before processing begins.

        Validates the input port contexts and builds a context for every
        configured output port. Checks for consistent sampling rates,
        broadcast-compatible channel counts, matching frame sizes across
        synchronous ports, and compatible port types.

        Note:
            The channel count stored in each input context is overwritten
            in place with the broadcast (maximum) channel count.

        Args:
            data: Dictionary mapping port names to numpy arrays. Not used by
                this implementation.
            port_context_in: Dictionary mapping input port names to contexts.

        Returns:
            Dictionary mapping output port names to context dictionaries.
            Each output port receives its own top-level dictionary, so
            modifying one output context does not affect the others.

        Raises:
            ValueError: If a required key is missing, no port reports a
                sampling rate, or sampling rates, channel counts, frame sizes
                or port types are inconsistent.
        """
        contexts = list(port_context_in.values())

        # Every input port context must carry these keys.
        for context in contexts:
            if Constants.Keys.CHANNEL_COUNT not in context:
                raise ValueError("channel_count must be provided in context.")
            if Constants.Keys.FRAME_SIZE not in context:
                raise ValueError("frame_size must be provided in context.")
            if IPort.Configuration.Keys.TIMING not in context:
                raise ValueError("timing must be provided in context.")

        # Sampling rates: ports reporting one must agree, and at least one
        # port must report a sampling rate (an empty set also fails here).
        sr_key = Constants.Keys.SAMPLING_RATE
        sampling_rates = {
            md[sr_key] for md in contexts if md.get(sr_key) is not None
        }
        if len(sampling_rates) != 1:
            raise ValueError("All ports must have the same sampling rate.")

        # Channel counts: single-channel ports may be broadcast to the one
        # multi-channel count in use. A reference value of 1 is added so
        # that more than two distinct values means two incompatible
        # multi-channel counts.
        cc_key = Constants.Keys.CHANNEL_COUNT
        channel_counts = {
            md[cc_key] for md in contexts if md.get(cc_key) is not None
        }
        channel_counts.add(1)
        if len(channel_counts) > 2:
            raise ValueError("All ports must have the same channel count.")
        broadcast_count = max(channel_counts)
        for md in contexts:
            if md.get(cc_key) is not None:
                md[cc_key] = broadcast_count  # broadcast, mutates input

        # Frame sizes: only synchronous ports have to agree.
        fsz_key = Constants.Keys.FRAME_SIZE
        timing_key = IPort.Configuration.Keys.TIMING
        frame_sizes = {
            md[fsz_key]
            for md in contexts
            if md[timing_key] == Constants.Timing.SYNC
            and md.get(fsz_key) is not None
        }
        if len(frame_sizes) > 1:
            raise ValueError("All ports must have the same frame size.")

        # Port types: "Any" and missing types are compatible with anything.
        type_key = IPort.Configuration.Keys.TYPE
        types = [md.get(type_key, None) for md in contexts]
        distinct_types = {tp for tp in types if tp != "Any" and tp is not None}
        if len(distinct_types) > 1:
            raise ValueError("All ports must have the same type.")

        # Merge the input contexts key by key.
        merged: dict = {}
        for key in set().union(*contexts):
            values = {
                port: config[key]
                for port, config in port_context_in.items()
                if key in config
            }
            value_list = list(values.values())
            if len(values) == 1:
                # Key present in a single port: use its value directly.
                merged[key] = value_list[0]
            elif all(value == value_list[0] for value in value_list[1:]):
                # Identical across ports: use (a copy of) the common value.
                merged[key] = copy.deepcopy(value_list[0])
            else:
                # Values differ: keep the per-port mapping.
                merged[key] = values

        # Each output port gets its own top-level dict. Previously one dict
        # object was shared, so editing one port's context edited them all.
        op_key = self.Configuration.Keys.OUTPUT_PORTS
        name_key = IPort.Configuration.Keys.NAME
        return {op[name_key]: dict(merged) for op in self.config[op_key]}

    @abstractmethod
    def step(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Process data at each discrete time step.

        Abstract method that must be implemented by subclasses to define
        their specific signal processing behavior.

        Args:
            data: Dictionary mapping input port names to numpy arrays.

        Returns:
            Dictionary mapping output port names to numpy arrays.
            May return None if no output is produced.
        """
        pass  # pragma: no cover