Source code for ironflow.nodes.pyiron.atomistics_nodes

# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.
"""
Ryven nodes specific to pyiron (or with ironflow improvements, like an ipywidgets representation).
"""

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import matplotlib.pylab as plt
import numpy as np

from pyiron_atomistics import Project
from pyiron_atomistics.atomistics.structure.factory import StructureFactory
from pyiron_atomistics.atomistics.job.atomistic import AtomisticGenericJob
from pyiron_atomistics.lammps import list_potentials

from ironflow.gui.canvas_widgets import ButtonNodeWidget
from ironflow.model import dtypes, NodeInputBP, NodeOutputBP
from ironflow.model.node import Node
from ironflow.nodes.std.special_nodes import DualNodeBase

if TYPE_CHECKING:
    from pyiron_base import HasGroups

# Module-level structure factory; BulkStructure_Node builds its structures via `STRUCTURE_FACTORY.bulk`.
STRUCTURE_FACTORY = StructureFactory()


class BeautifulHasGroups:
    """
    A helper class for giving classes that inherit from `pyiron_base.HasGroups` a more
    appealing representation in ipywidgets.

    Args:
        has_groups (HasGroups | None): The object to represent. `None` is allowed and is
            represented as `None` (JSON `null`).
    """

    def __init__(self, has_groups: HasGroups | None):
        self._has_groups = has_groups

    def to_builtin(self, has_groups: HasGroups | None = None) -> dict | None:
        """
        Convert a `HasGroups` object into nested builtin dictionaries.

        Groups become nested dictionaries; nodes are converted with `str`.

        Args:
            has_groups (HasGroups | None): The object to convert. (Default is None, which
                converts the object this helper was instantiated with.)

        Returns:
            (dict | None): The nested dictionary representation, or None when there is
                nothing to convert.
        """
        return self._convert(
            has_groups if has_groups is not None else self._has_groups
        )

    def _convert(self, has_groups: HasGroups | None) -> dict | None:
        """Recursively convert `has_groups`, never falling back to the stored object."""
        # Recursion goes through this helper rather than `to_builtin`, so a `None` child
        # is rendered as `None` instead of restarting from the root (infinite recursion).
        if has_groups is None:
            return None
        repr_dict = {
            k: self._convert(has_groups[k]) for k in has_groups.list_groups()
        }
        for k in has_groups.list_nodes():
            repr_dict[k] = str(has_groups[k])
        return repr_dict

    def _repr_json_(self):
        return self.to_builtin()

    def _repr_html_(self):
        import html  # only needed for this rendering path

        name = self._has_groups.__class__.__name__
        plain = f"{name}({json.dumps(self.to_builtin(), indent=2, default=str)})"
        # Node values are `str(...)` conversions, which often contain angle brackets
        # (e.g. "<... object at 0x...>"). Escape them so the browser shows them verbatim
        # instead of parsing them as HTML tags.
        return "<pre>" + html.escape(plain) + "</pre>"
class Project_Node(Node):
    """
    Create a pyiron project.

    Inputs:
        name (str): The name of the project. Will access existing project data under that
            name. (Default is ".".)

    Outputs:
        project (pyiron_atomistics.Project): The project object.
    """

    # this __doc__ string will be displayed as tooltip in the editor

    title = "Project"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Char(default="."), label="name"),
    ]
    init_outputs = [
        NodeOutputBP(label="project"),
    ]
    color = "#aabb44"

    def place_event(self):
        super().place_event()
        # Build the project right away so the output exists as soon as the node is placed.
        self.update()

    def update_event(self, inp=-1):
        project_name = self.input(0)
        self.set_output_val(0, Project(project_name))

    @property
    def _project(self):
        return self.output(0)

    @property
    def representations(self) -> dict:
        # Todo: Also show `self._project.job_table()` (when the project is not None) once it
        #  can be displayed without breaking the gui size; right now it automatically grows
        #  the gui because the table is so wide.
        shown = {}
        shown["name"] = str(self.input(0))
        return shown
class OutputsOnlyAtoms(Node, ABC):
    """
    A helper class that manages representations for nodes whose only output is a
    `pyiron_atomistics.Atoms` object.

    Outputs:
        structure (pyiron_atomistics.Atoms): An atomic structure.
    """

    init_outputs = [
        NodeOutputBP(label="structure"),
    ]
    color = "#aabb44"

    @abstractmethod
    def update_event(self, inp=-1):
        """Must set output 0 to an instance of pyiron_atomistics.atomistics.atoms.Atoms"""
        pass

    @property
    def representations(self) -> dict:
        """
        Returns:
            (dict): A 3D plot of the structure (None while there is no structure yet)
                and the structure itself.
        """
        structure = self.output(0)
        # Some subclasses (e.g. Repeat_Node) only set their output once their input is
        # connected; the output is presumably None until then, and None has no `plot3d`.
        return {
            "plot3d": structure.plot3d() if structure is not None else None,
            "print": structure,
        }
class BulkStructure_Node(OutputsOnlyAtoms):
    """
    Generate a bulk atomic structure.

    Inputs:
        element (str): The atomic symbol for the desired atoms. (Default is "Fe".)
        crystal_structure (str | None): Must be one of sc, fcc, bcc, hcp, diamond,
            zincblende, rocksalt, cesiumchloride, fluorite or wurtzite.
        a (float | None): Lattice constant.
        c (float | None): Lattice constant.
        c_over_a (float | None): c/a ratio used for hcp. Default is ideal ratio: sqrt(8/3).
        u (float | None): Internal coordinate for Wurtzite structure.
        orthorhombic (bool): Construct orthorhombic unit cell instead of primitive cell.
            (Takes precedence over cubic flag when both are true.)
        cubic (bool): Construct cubic unit cell if possible.

    Outputs:
        structure (pyiron_atomistics.Atoms): A mono-species bulk structure.
    """

    # this __doc__ string will be displayed as tooltip in the editor

    title = "BulkStructure"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Char(default="Fe"), label="element"),
        NodeInputBP(
            dtype=dtypes.Choice(
                default=None,
                items=[
                    None,
                    "sc",
                    "fcc",
                    "bcc",
                    "hcp",
                    "diamond",
                    "zincblende",
                    "rocksalt",
                    "cesiumchloride",
                    "fluorite",
                    "wurtzite",
                ],
            ),
            label="crystal_structure",
        ),
        NodeInputBP(dtype=dtypes.Float(default=None), label="a"),
        NodeInputBP(dtype=dtypes.Float(default=None), label="c"),
        NodeInputBP(dtype=dtypes.Float(default=None), label="c_over_a"),
        NodeInputBP(dtype=dtypes.Float(default=None), label="u"),
        NodeInputBP(dtype=dtypes.Boolean(default=False), label="orthorhombic"),
        NodeInputBP(dtype=dtypes.Boolean(default=False), label="cubic"),
    ]

    def place_event(self):
        super().place_event()
        # Produce the default structure as soon as the node is placed.
        self.update()

    def update_event(self, inp=-1):
        element = self.input(0)
        # Input ports 1..7 map, in order, onto these `bulk` keyword arguments.
        keywords = ("crystalstructure", "a", "c", "covera", "u", "orthorhombic", "cubic")
        options = {key: self.input(port) for port, key in enumerate(keywords, start=1)}
        self.set_output_val(0, STRUCTURE_FACTORY.bulk(element, **options))
class Repeat_Node(OutputsOnlyAtoms):
    """
    Repeat atomic structure supercell.

    Inputs:
        structure (pyiron_atomistics.Atoms): The structure to repeat periodically.
        all (int): The number of times to repeat it in each of the three Bravais lattice
            directions.

    Outputs:
        structure (pyiron_atomistics.Atoms): A repeated copy of the input structure.
    """

    # this __doc__ string will be displayed as tooltip in the editor

    title = "Repeat"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="structure"),
        NodeInputBP(dtype=dtypes.Integer(default=1, bounds=(1, 100)), label="all"),
    ]

    def update_event(self, inp=-1):
        repeated = self.input(0).repeat(self.input(1))
        self.set_output_val(0, repeated)
class ApplyStrain_Node(OutputsOnlyAtoms):
    """
    Apply strain on atomic structure supercell.

    Inputs:
        structure (pyiron_atomistics.Atoms): The atomic structure to strain.
        strain (float): The isotropic strain to apply, where 0 is unstrained. (Default
            is 0.)

    Outputs:
        structure (pyiron_atomistics.Atoms): A strained copy of the input structure.
    """

    title = "ApplyStrain"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="structure"),
        NodeInputBP(dtype=dtypes.Float(default=0, bounds=(-100, 100)), label="strain"),
    ]

    def update_event(self, inp=-1):
        # `return_box=True` asks for a strained copy rather than modifying the input.
        strained = self.input(0).apply_strain(
            float(self.input(1)),
            return_box=True,
        )
        self.set_output_val(0, strained)
class Lammps_Node(Node):
    """
    Create and run a LAMMPS job (work in progress).

    Inputs:
        run (exec): Create a `Lammps` job from the current inputs and run it.
        remove (exec): Remove the most recently run job from the project.
        project (pyiron_atomistics.Project): The project in which to create the job.
        name (str): The job name. (Default is "job".)
        structure (pyiron_atomistics.Atoms): The structure to simulate.
        potential (dtypes.Choice): The interatomic potential. The choices are repopulated
            from `list_potentials` whenever the structure input changes.

    Outputs:
        (exec): Fired after the job has been run.
        job: The job that was run, or None after it has been removed.
    """

    title = "Lammps"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(type_="exec", label="run"),
        NodeInputBP(type_="exec", label="remove"),
        NodeInputBP(dtype=dtypes.Data(size="m"), label="project"),
        NodeInputBP(dtype=dtypes.Char(default="job"), label="name"),
        NodeInputBP(dtype=dtypes.Data(size="m"), label="structure"),
        NodeInputBP(
            dtype=dtypes.Choice(
                default="Set structure first", items=["Set structure first"]
            ),
            label="potential",
        ),
    ]
    init_outputs = [
        NodeOutputBP(type_="exec"),
        NodeOutputBP(label="job"),
    ]
    color = "#5d95de"

    @property
    def _project(self):
        return self.input(2)

    @property
    def _name(self):
        return self.input(3)

    @property
    def _structure(self):
        return self.input(4)

    @property
    def _potential(self):
        return self.input(5)

    def _run(self):
        job = self._project.create.job.Lammps(self._name)
        job.structure = self._structure
        job.potential = self._potential
        # Remember the job so `remove` acts on it even if the name input changes later.
        self._job = job
        job.run()
        self.set_output_val(1, job)
        self.exec_output(0)

    def _remove(self):
        job = getattr(self, "_job", None)
        if job is None:
            # Nothing has been run yet, or the last job was already removed.
            return
        try:
            # Remove based on the run job, not the input name which might have changed...
            self._project.remove_job(job.name)
        except AttributeError:
            # Best-effort removal, e.g. when the project input is no longer set.
            return
        self._job = None
        self.set_output_val(1, None)

    def _update_potential_choices(self):
        potl_input = self.inputs[5]
        last_potential = potl_input.val
        structure = self.inputs[4].val
        if structure is None:
            # `list_potentials` needs a structure. Fall back to the placeholder choice
            # (e.g. when the structure input is unset) instead of raising.
            placeholder = self.init_inputs[5].dtype.default
            potl_input.val = placeholder
            potl_input.dtype.items = [placeholder]
            return
        available_potentials = list_potentials(structure)
        if len(available_potentials) == 0:
            potl_input.val = "No valid potential"
            potl_input.dtype.items = ["No valid potential"]
        else:
            # Keep the user's previous choice if it is still valid for the new structure.
            if last_potential not in available_potentials:
                potl_input.val = available_potentials[0]
            potl_input.dtype.items = available_potentials

    def update_event(self, inp=-1):
        if inp == 0:
            self._run()
        elif inp == 1:
            self._remove()
        elif inp == 4:
            self._update_potential_choices()

    @property
    def representations(self) -> dict:
        return {"job": BeautifulHasGroups(self.output(1))}
class GenericOutput_Node(Node):
    """
    Select Generic Output item.

    Inputs:
        job (AtomisticGenericJob): A job with an `output` attribute of type
            `pyiron_atomistics.atomistics.job.atomistic.GenericOutput`.
        field (dtypes.Choice): Which output field to look at. Automatically populates once
            the job is valid.

    Outputs:
        output (numpy.ndarray): The selected output field.
    """

    version = "v0.1"
    title = "GenericOutput"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="job"),
        NodeInputBP(
            dtype=dtypes.Choice(
                default="Input an atomistic job", items=["Input an atomistic job"]
            ),
            label="field",
        ),
    ]
    init_outputs = [
        NodeOutputBP(label="output"),
    ]
    color = "#c69a15"

    @property
    def _job(self):
        return self.input(0)

    def _available_fields(self) -> list:
        """Return the generic output field names of the input job, or [] if there are none."""
        if not isinstance(self._job, AtomisticGenericJob):
            return []
        generic = self._job["output/generic"]
        # NOTE(review): presumably None for a job that has not produced output yet.
        return [] if generic is None else list(generic.list_nodes())

    def _update_fields(self):
        field_input = self.inputs[1]
        fields = self._available_fields()
        if fields:
            field_input.dtype.items = fields
            field_input.val = fields[0]
        else:
            default = self.init_inputs[1].dtype.default
            # Note: It would be sensible to use `self.init_inputs[1].dtype.items` here, but
            # that field gets updated to `self.inputs[1].dtype.items`, probably because of
            # the mutability of lists, so build a fresh list from the default instead.
            field_input.dtype.items = [default]
            field_input.val = default

    def _update_value(self):
        if isinstance(self._job, AtomisticGenericJob):
            val = self._job[f"output/generic/{self.input(1)}"]
        else:
            val = None
        self.set_output_val(0, val)

    def update_event(self, inp=-1):
        if inp == 0:
            # A new job changes which fields exist, so refresh the choices first.
            self._update_fields()
            self._update_value()
        elif inp == 1:
            self._update_value()
class IntRand_Node(Node):
    """
    Generate random non-negative integers.

    Inputs:
        high (int): Exclusive upper bound; values are drawn uniformly from [0, high).
            (Default is 1, which always yields 0.)
        length (int): How many random numbers to generate. (Default is 1.)

    Outputs:
        randint (numpy.ndarray): The `length` randomly generated value(s).
    """

    # this __doc__ string will be displayed as tooltip in the editor
    title = "IntRandom"
    init_inputs = [
        # The bounds must contain the default of 1. `randint(0, high)` needs high >= 1.
        NodeInputBP(dtype=dtypes.Integer(default=1, bounds=(1, 100)), label="high"),
        NodeInputBP(dtype=dtypes.Integer(default=1, bounds=(1, 100)), label="length"),
    ]
    init_outputs = [
        NodeOutputBP(label="randint"),
    ]
    color = "#aabb44"

    def update_event(self, inp=-1):
        # Passing `size` makes numpy always return an array, even for length 1.
        val = np.random.randint(0, high=self.input(0), size=self.input(1))
        self.set_output_val(0, val)
class JobName_Node(Node):
    """
    Create job name for parameters.

    Inputs:
        base (str): The stem for the final name. (Default is "job_".)
        float (float): The parameter value to add to the name.

    Outputs:
        job_name (str): The base plus float sanitized into a valid job name.

    Todo:
        There has been some work in pyiron_base on getting a cleaner job name sanitizer,
        so lean on that.
    """

    title = "JobName"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Char(default="job_"), label="base"),
        NodeInputBP(dtype=dtypes.Float(default=0), label="float"),
    ]
    init_outputs = [
        NodeOutputBP(label="job_name"),
    ]
    color = "#aabb44"

    def update_event(self, inp=-1):
        base = self.input(0)
        # Spell out characters that are not allowed in job names: "-" -> "m", "." -> "p".
        suffix = str(float(self.input(1))).translate(str.maketrans({"-": "m", ".": "p"}))
        self.set_output_val(0, base + suffix)
class Linspace_Node(Node):
    """
    Generate a linear mesh in a given range using `np.linspace`.

    Inputs:
        min (int): The lower bound (inclusive). (Default is 1.)
        max (int): The upper bound (inclusive). (Default is 2.)
        steps (int): How many evenly spaced samples to take in [min, max]. (Default is
            10.)

    Outputs:
        linspace (numpy.ndarray): A uniform sampling over the requested range.
    """

    # this __doc__ string will be displayed as tooltip in the editor
    title = "Linspace"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Integer(default=1, bounds=(0, 100)), label="min"),
        NodeInputBP(dtype=dtypes.Integer(default=2, bounds=(0, 100)), label="max"),
        NodeInputBP(dtype=dtypes.Integer(default=10, bounds=(1, 100)), label="steps"),
    ]
    init_outputs = [
        NodeOutputBP(label="linspace"),
    ]
    color = "#aabb44"

    def place_event(self):
        super().place_event()
        # Compute the default mesh immediately on placement.
        self.update()

    def update_event(self, inp=-1):
        start, stop, num = (self.input(port) for port in range(3))
        mesh = np.linspace(start, stop, num)
        self.set_output_val(0, mesh)
class Plot3d_Node(Node):
    """
    Plot a structure with NGLView.

    Inputs:
        structure (pyiron_atomistics.Atoms): The structure to plot.

    Outputs:
        plot3d (nglview.widget.NGLWidget): The plot object.
        structure (pyiron_atomistics.Atoms): The raw structure object passed in.
    """

    title = "Plot3d"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="structure"),
    ]
    init_outputs = [
        NodeOutputBP(type_="data", label="plot3d"),
        NodeOutputBP(type_="data", label="structure"),
    ]
    color = "#5d95de"

    def update_event(self, inp=-1):
        structure = self.input(0)
        self.set_output_val(0, structure.plot3d())
        # Pass the structure through unchanged so it can be chained further.
        self.set_output_val(1, structure)
class Matplot_Node(Node):
    """
    A 2D matplotlib plot.

    Inputs:
        x (list|numpy.ndarray|...): Data for the x-axis.
        y (list|numpy.ndarray|...): Data for the y-axis.

    Outputs:
        fig (matplotlib.figure.Figure): The resulting figure after a
            `matplotlib.pyplot.plot` call on x and y.
    """

    title = "MatPlot"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="x"),
        NodeInputBP(dtype=dtypes.Data(size="m"), label="y"),
    ]
    init_outputs = [
        NodeOutputBP(type_="data", label="fig"),
    ]
    color = "#5d95de"

    def update_event(self, inp=-1):
        super().update_event()
        # Build the figure with interactive mode off; presumably so pyplot does not render
        # it by itself and it is only shown through this node's output.
        was_interactive = plt.isinteractive()
        plt.ioff()
        try:
            fig = plt.figure()
            plt.plot(self.input(0), self.input(1))
            self.set_output_val(0, fig)
        finally:
            # Restore the caller's interactive setting, even if plotting raised.
            if was_interactive:
                plt.ion()
class Sin_Node(Node):
    """
    Call `numpy.sin` on a value.

    Inputs:
        x (int|float|list|numpy.ndarray|...): The value to sine transform.

    Outputs:
        sin (float|numpy.ndarray): The sine of x.
    """

    title = "Sin"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(dtype=dtypes.Data(size="m"), label="x"),
    ]
    init_outputs = [
        NodeOutputBP(label="sin"),
    ]
    color = "#5d95de"

    def update_event(self, inp=-1):
        x = self.input(0)
        sine = np.sin(x)
        self.set_output_val(0, sine)
class Result_Node(Node):
    """Simply shows a value converted to str"""

    version = "v0.1"
    title = "Result"
    init_inputs = [
        NodeInputBP(type_="data"),
    ]
    color = "#c69a15"

    def __init__(self, params):
        super().__init__(params)
        # Last value received on the input; shown by the main widget.
        self.val = None

    def place_event(self):
        super().place_event()
        self.update()

    def view_place_event(self):
        self.main_widget().show_val(self.val)

    def update_event(self, inp=-1):
        self.val = self.input(0)
        # Only push the value to the widget when a GUI session exists.
        if not self.session.gui:
            return
        self.main_widget().show_val(self.val)
class ForEach_Node(Node):
    """
    Step through a list of elements, one element per "start" signal.

    Inputs:
        start (exec): Emit the next element on `e` and fire `loop`. Once every element has
            been emitted, fire `finished` instead.
        reset (exec): Restart iteration from the first element.
        elements (list): The elements to iterate over. Updating them also restarts
            iteration.

    Outputs:
        loop (exec): Fired after each element is emitted.
        e: The current element.
        finished (exec): Fired when `start` is received with no elements left.
    """

    title = "ForEach"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(type_="exec", label="start"),
        NodeInputBP(type_="exec", label="reset"),
        NodeInputBP(dtype=dtypes.List(), label="elements"),
    ]
    init_outputs = [
        NodeOutputBP(label="loop", type_="exec"),
        NodeOutputBP(label="e", type_="data"),
        NodeOutputBP(label="finished", type_="exec"),
    ]
    color = "#b33a27"
    # Index of the next element to emit.
    _count = 0

    def update_event(self, inp=-1):
        if inp == 0:
            elements = self.input(2)
            n_elements = 0 if elements is None else len(elements)
            if self._count < n_elements:
                # Emit the element at the current index *before* advancing, so the
                # first element is not skipped.
                self.set_output_val(1, elements[self._count])
                # Advance before firing `loop`: downstream nodes may feed straight back
                # into `start` before `exec_output` returns.
                self._count += 1
                self.exec_output(0)
            else:
                self.exec_output(2)
        elif inp > 0:
            # Both `reset` and a change of `elements` restart the iteration.
            self._count = 0
        self.val = self._count
class ExecCounter_Node(DualNodeBase):
    """
    Count exec signals.

    While active, every signal on the exec input increments a counter that is stored in
    `val`. While passive, `val` mirrors the value of input 0.
    """

    title = "ExecCounter"
    version = "v0.1"
    init_inputs = [
        NodeInputBP(type_="exec"),
    ]
    init_outputs = [
        NodeOutputBP(type_="exec"),
    ]
    color = "#5d95de"

    def __init__(self, params):
        super().__init__(params, active=True)
        self._count = 0

    def update_event(self, inp=-1):
        if not self.active:
            self.val = self.input(0)
        elif inp == 0:
            self._count += 1
            self.val = self._count
class Click_Node(Node):
    """
    Fire an exec signal on demand.

    Outputs:
        (exec): Fired on every update of this node (presumably triggered by its button
            widget).
    """

    title = "Click"
    version = "v0.1"
    main_widget_class = ButtonNodeWidget
    init_inputs = []
    init_outputs = [
        NodeOutputBP(type_="exec"),
    ]
    color = "#99dd55"

    def update_event(self, inp=-1):
        # Whatever triggered the update, just forward an exec signal.
        self.exec_output(0)
# The node classes this module exposes. NOTE(review): presumably collected by ironflow to
# populate the GUI's node list; the order may determine display order, so keep it stable.
nodes = [
    Project_Node,
    BulkStructure_Node,
    Repeat_Node,
    ApplyStrain_Node,
    Lammps_Node,
    JobName_Node,
    GenericOutput_Node,
    Plot3d_Node,
    IntRand_Node,
    Linspace_Node,
    Sin_Node,
    Result_Node,
    ExecCounter_Node,
    Matplot_Node,
    Click_Node,
    ForEach_Node,
]