# Source code for ironflow.model.dtypes

# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.
"""
Initially wrapped dtypes to avoid mutable defaults, but after that expanded to
facilitate strict type checking when flow connections are made, and to add batching.

Node ports were overridden so that by default they come with an `Untyped` dtype, and
always have a batching flag.

Further, the dtypes are broken down into broad categories of `Data`, `List`, and
`Choice` with different behaviours under regular and batched conditions.

Warning:
    Any additional types defined here later need to be added to the list in
    `DType.from_str` to work with (de)serialization.

The implementation of dtypes changes in ryvencore v0.4, so this file may be short-lived.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional

import numpy as np
from ryvencore.dtypes import DType as DTypeCore


def isiterable(obj):
    """
    Report whether ``obj`` can be iterated.

    Args:
        obj: Any object.

    Returns:
        bool: True when ``iter(obj)`` succeeds, False when it raises
        ``TypeError``. Any other exception propagates unchanged.
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True
def other_classes_are_subset(other, reference):
    """
    Check that every class in ``other`` is covered by ``reference``.

    A class is covered when it is a subclass of (or the same as) at least one
    entry of ``reference``.

    Args:
        other: Iterable of classes to check.
        reference: Iterable of classes to check against; it is re-iterated
            for each element of ``other``.

    Returns:
        bool: True if all of ``other`` is covered (vacuously True when
        ``other`` is empty), otherwise False.
    """
    for candidate in other:
        if not any(issubclass(candidate, ref) for ref in reference):
            return False
    return True
class DType(DTypeCore, ABC):
    """
    Abstract base for ironflow dtypes, extending the ryvencore ``DType``.

    Adds three attributes on top of the ryvencore base:

    - ``valid_classes``: list of classes values/dtypes are checked against.
    - ``allow_none``: whether ``None`` is an acceptable value.
    - ``batched``: whether the dtype holds a batch of values.

    All three are registered via ``add_data`` (presumably so they take part in
    ryvencore state (de)serialization -- confirm against ryvencore).

    Subclasses implement ``_accepts_dtype`` (dtype-to-dtype connection checks),
    ``_accepts_non_none_instance`` (unbatched value checks) and
    ``_batch_accepts_instance`` (batched value checks).
    """

    def __init__(
        self,
        default,
        bounds: tuple = None,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        super().__init__(
            default=default,
            bounds=bounds,
            doc=doc,
            _load_state=_load_state,
        )
        # Only initialise from arguments when not restoring a saved state.
        if _load_state is None:
            # Normalise to a fresh list: nothing -> empty, a list -> a copy
            # (so the caller's list is never shared), anything else -> wrapped.
            if valid_classes is None:
                self.valid_classes = []
            elif isinstance(valid_classes, list):
                self.valid_classes = list(valid_classes)
            else:
                self.valid_classes = [valid_classes]
            self.allow_none = allow_none
            self.batched = batched
        for attribute_name in ("valid_classes", "allow_none", "batched"):
            self.add_data(attribute_name)

    @staticmethod
    def from_str(s):
        """
        Map a string of the form ``"DType.<ClassName>"`` to a dtype class.

        Only the dtypes defined in this module are searched, not ryvencore's.

        Args:
            s: The string to resolve.

        Returns:
            The matching class, or None if no local dtype matches.
        """
        local_dtypes = (
            Boolean,
            Choice,
            Data,
            Float,
            Integer,
            List,
            String,
            Untyped,
        )
        return next(
            (cls for cls in local_dtypes if s == "DType." + cls.__name__),
            None,
        )

    def _classes_are_subset(self, other_classes):
        # True when every class in other_classes inherits from one of ours.
        return other_classes_are_subset(other_classes, self.valid_classes)

    def accepts(self, other: DType | Any | None):
        """
        Check whether ``other`` may feed into this dtype.

        Args:
            other: Either another ``DType`` (checked by dtype) or a plain value
                (checked by instance, which includes ``None``).

        Returns:
            bool: Whether the connection/value is acceptable.
        """
        if not isinstance(other, DType):
            return self._accepts_instance(other)
        return self._accepts_dtype(other)

    @abstractmethod
    def _accepts_dtype(self, other: DType):
        """Decide whether a connection from ``other`` is acceptable."""

    def _accepts_instance(self, val: Any):
        # Batched dtypes delegate entirely to the batch check; unbatched ones
        # accept an allowed None or a valid non-None value.
        if not self.batched:
            return self._accepts_none(val) or self._accepts_non_none_instance(val)
        return self._batch_accepts_instance(val)

    @abstractmethod
    def _batch_accepts_instance(self, val: Any):
        """Decide whether ``val`` is acceptable as a batch of values."""

    def _accepts_none(self, val: Any):
        return val is None and self.allow_none

    @abstractmethod
    def _accepts_non_none_instance(self, val: Any):
        """Decide whether a single (unbatched) value is acceptable."""

    def valid_val(self, val: Any):
        """Return whether ``val`` is an acceptable value for this dtype."""
        return self._accepts_instance(val)

    def _surprise_none_possible(self, other: DType):
        # An upstream dtype that may carry None feeding one that may not.
        return other.allow_none and not self.allow_none
class Untyped(DType):
    """
    A dtype that places no class restriction on its value.

    Connections involving Untyped are always checked by value (instance
    check), never by dtype: ``_accepts_dtype`` raises, and other dtypes are
    expected to check against an Untyped output by value as well. That means
    Untyped can't be used to pre-wire a graph that has missing data.

    When Untyped as an input receives output connections...

    Normally:
        - Accept anything.

    When batched:
        - Accept any value that has an ``__iter__`` attribute.

    The constructor always sets ``allow_none=True`` (when no state is loaded),
    so unbatched Untyped accepts any value, including ``None``.
    """

    def __init__(
        self,
        doc: str = "",
        _load_state=None,
        batched=False,
    ):
        super().__init__(
            default=None,
            bounds=None,
            doc=doc,
            _load_state=_load_state,
            valid_classes=None,
            allow_none=True,
            batched=batched,
        )

    def _accepts_dtype(self, other: DType):
        """
        Never called in valid usage; Untyped matches are done by value.

        Raises:
            ValueError: Always.
        """
        # Fixed: the two f-string fragments previously had no separating space,
        # producing "...explain howyou got to this error."
        raise ValueError(
            f"Match checks to {self.__class__.__name__} should always be done by "
            f"value, not by dtype. Please contact a package maintainer and explain "
            f"how you got to this error."
        )

    def _accepts_non_none_instance(self, val: Any):
        # Untyped places no restriction on non-None values.
        return True

    def _batch_accepts_instance(self, val: Any):
        # Batched values only need to look iterable.
        return hasattr(val, "__iter__")
class Data(DType):
    """
    For most types of data.

    When Data as an input receives output connections...

    Normally:
        - Data output: output valid classes must be a subset of input valid
          classes, one of the input or output dtype classes must inherit from
          the other (or be the same), both must have the same batching, and the
          output must not allow None when the input does not.
        - Untyped output: output value must be an instance of valid classes.
        - All else: Fail.

    When batched:
        - Batched Data output: same as the unbatched case, but now both are
          batched.
        - Untyped output: output value must be iterable and each element must
          be an instance of valid classes (None elements only if None is
          allowed).
        - Unbatched List output: output valid classes must be a subset of input
          valid classes.
        - All else: Fail.

    Data is valid when the value is an instance of the valid classes, or is
    None and None is allowed.
    """

    def __init__(
        self,
        default=None,
        size: str = "m",
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        """
        size: 's' / 'm' / 'l'
        """
        self.size = size
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
        self.add_data("size")

    def _accepts_dtype(self, other: DType):
        """
        Decide whether a connection from dtype ``other`` is acceptable.

        Raises:
            ValueError: If ``other`` is Untyped; those must be checked by value.
        """
        if isinstance(other, Untyped):
            # Fixed: `Untyped.__class__.__name__` is the metaclass name (e.g.
            # "ABCMeta"), not "Untyped".
            raise ValueError(
                f"Match checks against {Untyped.__name__} should always be "
                f"done by value, not by dtype. Please contact a package maintainer and "
                f"explain how you got to this error."
            )
        elif (
            isinstance(other, self.__class__) or isinstance(self, other.__class__)
        ) and other.batched == self.batched:
            return self._classes_are_subset(
                other.valid_classes
            ) and not self._surprise_none_possible(other)
        elif self.batched and isinstance(other, List) and not other.batched:
            return self._classes_are_subset(other.valid_classes)
        else:
            return False

    def _accepts_non_none_instance(self, val: Any):
        return any(isinstance(val, c) for c in self.valid_classes)

    def _batch_accepts_instance(self, val: Any):
        """
        Accept an iterable whose non-None element types are all valid classes.

        None elements are only accepted when ``allow_none`` is set.
        """
        if not hasattr(val, "__iter__"):
            return False
        # Materialise once: the original iterated ``val`` twice, so a
        # single-pass iterable (e.g. a generator) was consumed by the None
        # check before the class check ever saw its elements.
        items = list(val)
        if not self.allow_none and any(v is None for v in items):
            return False
        return self._classes_are_subset({type(v) for v in items if v is not None})
class Integer(Data):
    """
    Integer-valued data.

    Unless ``valid_classes`` is given, accepts Python ``int`` and NumPy
    ``np.integer`` instances. Also stores an optional ``bounds`` tuple,
    registered via ``add_data``.
    """

    def __init__(
        self,
        default: int = 0,
        bounds: tuple = None,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        if valid_classes is None:
            valid_classes = [int, np.integer]
        # NOTE(review): bounds is assigned before super().__init__, and Data
        # does not forward a bounds argument to DType. If the ryvencore base
        # constructor assigns `self.bounds` itself, this value would be reset
        # to None -- confirm against ryvencore.
        self.bounds = bounds
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
        self.add_data("bounds")
class Float(Data):
    """
    Floating-point data.

    Unless ``valid_classes`` is given, accepts Python ``float`` and NumPy
    ``np.floating`` instances. Also stores an optional ``bounds`` tuple and a
    ``decimals`` count, both registered via ``add_data``.
    """

    def __init__(
        self,
        default: float = 0.0,
        bounds: tuple = None,
        decimals: int = 10,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        if valid_classes is None:
            valid_classes = [float, np.floating]
        # NOTE(review): bounds is assigned before super().__init__, and Data
        # does not forward a bounds argument to DType. If the ryvencore base
        # constructor assigns `self.bounds` itself, this value would be reset
        # to None -- confirm against ryvencore.
        self.bounds = bounds
        self.decimals = decimals
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
        for attribute_name in ("bounds", "decimals"):
            self.add_data(attribute_name)
class Boolean(Data):
    """
    Boolean data.

    Unless ``valid_classes`` is given, accepts Python ``bool`` and NumPy
    ``np.bool_`` instances.
    """

    def __init__(
        self,
        default: bool = False,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        if valid_classes is None:
            valid_classes = [bool, np.bool_]
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
class String(Data):
    """
    String data.

    Unless ``valid_classes`` is given, accepts Python ``str`` and NumPy
    ``np.str_`` instances.
    """

    def __init__(
        self,
        default: str = "",
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        if valid_classes is None:
            valid_classes = [str, np.str_]
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
class Choice(DType):
    """
    Data that must be chosen from among a list of items.

    When Choice as an input receives output connections...

    Normally:
        - Unbatched Data output: output valid classes must be a subset, and the
          output must not allow None when the input does not.
        - Untyped output: output value must be in the items list.
        - All else: Fail.

    When batched:
        - Batched Data output: output valid classes must be a subset.
        - Unbatched List output: output valid classes must be a subset.
        - (In both batched cases the output must also not allow None when the
          input does not.)
        - Untyped: output value must be a list or NumPy array, and each element
          must be in the items list.
        - All else: Fail.

    Choice is valid when the value is in the items list, or value is None and
    None is allowed.

    Note that when making Data (or List) connections, the connection may be
    allowed but still result in an invalid value state (in cases where the
    output value does not match the input items list).
    """

    def __init__(
        self,
        default=None,
        items: Optional[list] = None,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        # Fresh list per instance rather than a shared mutable default.
        self.items = items if items is not None else []
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )
        self.add_data("items")

    def _accepts_dtype(self, other: DType):
        """
        Decide whether a connection from dtype ``other`` is acceptable.

        Raises:
            ValueError: If ``other`` is Untyped; those must be checked by value.
        """
        # TODO: Temporary code duplication while splitting Data and Choice
        if isinstance(other, Untyped):
            # Fixed: `Untyped.__class__.__name__` is the metaclass name (e.g.
            # "ABCMeta"), not "Untyped".
            raise ValueError(
                f"Match checks against {Untyped.__name__} should always be "
                f"done by value, not by dtype. Please contact a package maintainer and "
                f"explain how you got to this error."
            )
        if self.batched:
            dtype_ok = (isinstance(other, List) and not other.batched) or (
                isinstance(other, Data) and other.batched
            )
            classes_ok = self._classes_are_subset(other.valid_classes)
            return dtype_ok and classes_ok and not self._surprise_none_possible(other)
        elif isinstance(other, Data) and not other.batched:
            return self._classes_are_subset(
                other.valid_classes
            ) and not self._surprise_none_possible(other)
        else:
            return False

    def _accepts_non_none_instance(self, val: Any):
        return val in self.items

    def _batch_accepts_instance(self, val: Any):
        # Only lists and NumPy arrays count as batches; every element must be
        # one of the items.
        return isinstance(val, (list, np.ndarray)) and all(
            self._accepts_non_none_instance(v) for v in val
        )
class List(DType):
    """
    Data that is explicitly iterable.

    When List as an input receives output connections...

    Normally:
        - List output (batched or not): output valid classes must be a subset.
        - Batched Data output: output valid classes must be a subset.
        - Untyped output: output value must be iterable and each element must
          be an instance of valid classes.
        - All else: Fail.

    When batched:
        - Batched List output: output valid classes must be a subset.
        - Untyped: output value must be iterable, and each element must either
          be None (when None is allowed) or be iterable with every inner
          element an instance of a valid class.
        - All else: Fail.

    List is valid when the value is iterable and all elements are instances of
    the valid classes, or the value is None and None is allowed.

    Note: `allow_none` in this case determines whether the _entire dtype value_
    may be `None`. If you want to specify that the list-like object itself may
    _contain_ `None` values, add `type(None)` to the `valid_classes`.
    """

    def __init__(
        self,
        default: Optional[list] = None,
        doc: str = "",
        _load_state=None,
        valid_classes=None,
        allow_none=False,
        batched=False,
    ):
        if valid_classes is None:
            # NOTE(review): the default is the `list` type itself, which DType
            # wraps as `[list]`, so a default List only accepts elements that
            # are themselves lists. Confirm this is intended.
            valid_classes = list
        super().__init__(
            default=default,
            doc=doc,
            _load_state=_load_state,
            valid_classes=valid_classes,
            allow_none=allow_none,
            batched=batched,
        )

    def _accepts_dtype(self, other: DType):
        if self.batched:
            # A batched List only pairs with another batched List.
            return (
                isinstance(other, List)
                and other.batched
                and self._classes_are_subset(other.valid_classes)
            )
        compatible_kind = isinstance(other, List) or (
            isinstance(other, Data) and other.batched
        )
        if not compatible_kind:
            return False
        # TODO: Only other unbatched lists should be accepted to conform to spec
        #  At the moment, this is a very useful bug, since it lets us pass
        #  batched data to the `Transpose` and `Slice` nodes to modify them
        #  The correct fix is to introduce a new Matrix DType, of which
        #  List is a special case
        return self._classes_are_subset(other.valid_classes)

    def _accepts_non_none_instance(self, val: Any):
        if not isiterable(val):
            return False
        allowed = self.valid_classes
        return all(any(isinstance(item, cls) for cls in allowed) for item in val)

    def _batch_accepts_instance(self, val: Any):
        if not isiterable(val):
            return False
        # Each batch entry is itself a (possibly None) list-like value.
        return all(
            self._accepts_none(entry) or self._accepts_non_none_instance(entry)
            for entry in val
        )