# Source code for ironflow.model.flow

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from ryvencore import Flow as FlowCore, InfoMsgs
from ryvencore.NodePort import NodePort

from ironflow.model.dtypes import Untyped

if TYPE_CHECKING:
    from ironflow.model.dtypes import DType


class Flow(FlowCore):
    """
    Wraps `ryvencore.Flow.Flow` to override the connection validity check, so we can
    add type checking.
    """

    @staticmethod
    def batched_or_nothing(dtype: DType) -> str:
        """
        Return the prefix ``"batched "`` if `dtype` is batched, otherwise ``""``.

        Used to describe a dtype in the connection-check log message.
        """
        return "batched " if dtype.batched else ""

    @staticmethod
    def _ports_are_connected(p1: NodePort, p2: NodePort) -> bool:
        """Return True if `p1` and `p2` share at least one connection object."""
        return any(c in p2.connections for c in p1.connections)

    def check_connection_validity(self, p1: NodePort, p2: NodePort) -> bool:
        """
        Checks whether a considered connect action is legal.

        The structural rules come from ryvencore. Two ports may only be connected if
        they belong to different nodes, sit on opposite sides (input vs. output), and
        have the same port type. On top of that, ironflow checks that the input
        port's dtype accepts the output port's dtype (or current value).

        The dtype check is only performed when the structural rules pass. If they
        fail, there is no meaningful input/output pair to compare. The check is also
        skipped when the ports are already connected, i.e. the request is a
        disconnection.

        Args:
            p1: One port of the considered connection.
            p2: The other port of the considered connection.

        Returns:
            bool: Whether the connect action is legal. The same value is emitted on
            `self.connection_request_valid`.
        """
        # ryvencore.Flow.Flow content
        valid = True
        if p1.node == p2.node:
            valid = False
        if p1.io_pos == p2.io_pos or p1.type_ != p2.type_:
            valid = False

        # ironflow content
        # Only validate connections, not disconnections. Skip the type check entirely
        # if the request was already rejected structurally.
        if valid and not self._ports_are_connected(p1, p2):
            valid = self._dtypes_are_compatible(p1, p2)

        # ryvencore.Flow.Flow content
        self.connection_request_valid.emit(valid)
        return valid

    def _dtypes_are_compatible(self, p1: NodePort, p2: NodePort) -> bool:
        """
        Check whether the input port of the pair accepts data from the output port.

        Assumes `p1` and `p2` are on opposite sides. The caller guarantees this by
        only calling after the structural checks pass. Logs the outcome via
        `InfoMsgs.write`.
        """
        # io_pos == 1 identifies the input side (presumably ryvencore's
        # PortObjPos.INPUT -- TODO confirm).
        inp, out = (p1, p2) if p1.io_pos == 1 else (p2, p1)

        # Fall back to checking the output's current value when either side is
        # untyped, or when batching differs and the value is a list/array.
        # NOTE(review): assumes both ports carry a `dtype`; confirm for exec ports.
        if (
            isinstance(inp.dtype, Untyped)
            or isinstance(out.dtype, Untyped)
            or (
                inp.dtype.batched != out.dtype.batched
                and isinstance(out.val, (list, np.ndarray))
            )
        ):
            type_valid = inp.dtype.accepts(out.val)
            check_type = "value"
        else:
            type_valid = inp.dtype.accepts(out.dtype)
            check_type = "dtype"

        InfoMsgs.write(
            f"{inp.node.title}.{inp.label_str} input "
            f"{self.batched_or_nothing(inp.dtype)}{inp.dtype.__class__.__name__} "
            f"made a {check_type} check to receive "
            f"{out.node.title}.{out.label_str} output "
            f"{self.batched_or_nothing(out.dtype)}{out.dtype.__class__.__name__} "
            f"and returned {type_valid}"
        )
        return type_valid