Source code for ironflow.gui.workflows.boxes.node_interface.control

# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.
"""
For directly controlling node IO (as opposed to giving it to nodes by connecting them to other OI).
"""

from __future__ import annotations

import pickle
import base64
from typing import TYPE_CHECKING, Callable

import ipywidgets as widgets
import numpy as np
from ryvencore.InfoMsgs import InfoMsgs
from traitlets import TraitError

from ironflow.gui.draws_widgets import DrawsWidgets, draws_widgets
from ironflow.model.node import BatchingNode


if TYPE_CHECKING:
    from ipywidgets import DOMWidget
    from ironflow.gui.workflows.screen import WorkflowsGUI
    from ironflow.model.node import Node


[docs] def deserialize(data): return pickle.loads(base64.b64decode(data))
[docs] class NodeController(DrawsWidgets): """ Handles the creation of widgets for manually adjusting node input and viewing node info. """ main_widget_class = widgets.VBox def __new__(cls, screen: WorkflowsGUI, *args, **kwargs): return super().__new__(cls, *args, **kwargs) def __init__(self, screen: WorkflowsGUI, *args, **kwargs): super().__init__(*args, **kwargs) self.screen = screen self.node = None self._row_height = 30 # px self._border = "1px solid black" self.widget.layout = widgets.Layout( width="50%", border="", max_height="360px", margin="10px", padding="5px", )
[docs] def draw_for_node(self, node: Node | None) -> None: self.clear() self.node = node if self.node is not None: self.draw()
[docs] def update(self): if self.node is not None: self.draw()
@draws_widgets def draw(self) -> None: self.widget.children = [ self._draw_input_box(), self._draw_input_widget(), self._draw_info_box(), ] self.widget.layout.border = self._border return self.widget def _draw_input_widget(self) -> widgets.Widget: try: widget = self.node.input_widget(self.screen, self.node).widget widget.layout.height = "70px" widget.layout.border = "solid 1px blue" return widget except AttributeError: return widgets.Output() def _input_field_list(self) -> list[list[widgets.Widget]]: input = [] if hasattr(self.node, "inputs"): for i_c, inp in enumerate(self.node.inputs[:]): dtype = str(inp.dtype).split(".")[-1] disabled = self.node.block_updates or len(inp.connections) > 0 try: dtype_state = deserialize(inp.data()["dtype state"]) except TypeError: # `inp.data()` winds up calling `serialize` on `inp.get_val()` # This serialization is a pickle dump, which fails with structures (`Atoms`) # Just gloss over it for now dtype_state = { "val": "Serialization error -- please reconnect an input" } if inp.val is None: inp.val = dtype_state["val"] try: if dtype_state["batched"]: inp_widget = widgets.Text( f"Batched {dtype}", continuous_update=False, disabled=True, ) elif dtype == "Integer": inp_widget = widgets.IntText( value=inp.val, description="", continuous_update=False, disabled=disabled, ) elif dtype == "Float": inp_widget = widgets.FloatText( value=inp.val, description="", continuous_update=False, disabled=disabled, ) elif dtype == "Boolean": inp_widget = widgets.Checkbox( value=inp.val, indent=False, description="", disabled=disabled, ) elif dtype == "Choice": inp_widget = widgets.Dropdown( value=inp.val, options=inp.dtype.items, description="", ensure_option=True, disabled=disabled, ) elif dtype == "String" or dtype == "Char": inp_widget = widgets.Text( value=str(inp.val), continuous_update=False, disabled=disabled, ) else: inp_widget = widgets.Text( value=str(inp.val), continuous_update=False, disabled=True ) except TraitError as e: inp_widget = widgets.Label( value="Trait error -- check log and/or change input." ) InfoMsgs.write_err(e) description = inp.label_str if inp.label_str != "" else inp.type_ inp_widget.layout.width = "100px" inp_widget.observe(self._input_change_i(i_c), names="value") batch_button = widgets.ToggleButton( description="Batched", tooltip="Use batches batches of correctly typed data instead of " "instances", layout=widgets.Layout(width="75px"), disabled=inp.dtype is None or not isinstance(inp.node, BatchingNode), value=inp.dtype.batched if hasattr(inp, "dtype") else False, ) batch_button.observe(self._toggle_batching_i(i_c), names="value") reset_button = widgets.Button( tooltip="Reset to default", icon="refresh", layout=widgets.Layout(width="30px"), disabled=len(inp.connections) > 0, ) reset_button.on_click(self._input_reset_i(i_c, inp_widget)) input.append( [widgets.Label(description), inp_widget, batch_button, reset_button] ) return input def _input_change_i(self, i_c) -> Callable: def input_change(change: dict) -> None: # Todo: Test this in exec mode self.node.inputs[i_c].update(change["new"]) # self.node.update(i_c) self.screen.redraw_active_flow_canvas() return input_change def _toggle_batching_i(self, i_c) -> Callable: def toggle_batching(change: dict) -> None: try: InfoMsgs.write( f"Batching for {self.node.title}.{self.node.inputs[i_c].label_str} " f"set to {change['new']}" ) if change["new"]: self.node.inputs[i_c].batch() else: self.node.inputs[i_c].unbatch() self.screen.redraw_active_flow_canvas() self.draw() except AttributeError: pass return toggle_batching def _input_reset_i(self, i_c, associated_input_field) -> Callable: def input_reset(button: widgets.Button) -> None: default = self.node.inputs[i_c].dtype.default self.node.inputs[i_c].update(default) InfoMsgs.write( f"Value for {self.node.title}.{self.node.inputs[i_c].label_str} " f"reset to {default}" ) try: associated_input_field.value = default except TraitError: self.screen.update_node_control() finally: pass self.node.update(i_c) self.screen.redraw_active_flow_canvas() return input_reset def _draw_input_box(self) -> widgets.GridBox | widgets.Output: input_fields = self._input_field_list() n_fields = len(input_fields) if n_fields > 0: return widgets.GridBox( list(np.array(input_fields).flatten()), layout=widgets.Layout( grid_template_columns="auto auto auto auto", grid_auto_rows=f"{self._row_height}px", border="solid 1px blue", height=f"{self._box_height(n_fields)}px", # Automatic height like this really should be doable just with the CSS, # but for the life of me I can't get a CSS solution working right -Liam ), ) else: return widgets.Output() def _box_height(self, n_rows: int) -> int: return n_rows * self._row_height + 8 def _draw_info_box(self) -> widgets.VBox: glob_id_val = None if hasattr(self.node, "GLOBAL_ID"): glob_id_val = self.node.GLOBAL_ID global_id = widgets.Text( value=str(glob_id_val), description="GLOBAL_ID:", disabled=True, layout=widgets.Layout(height=f"{self._row_height}px"), ) title = widgets.Text( value=str(self.node.title), description="Title:", disabled=True, layout=widgets.Layout(height=f"{self._row_height}px"), ) info_box = widgets.VBox( [title, global_id], layout=widgets.Layout( height=f"{self._box_height(2)}px", border="solid 1px red", ), ) return info_box
[docs] def clear(self) -> None: self.node = None self.widget.children = [] self.widget.layout.border = "" super().clear()