# Source code for pysubgroup.subgroup_description

"""
Created on 28.04.2016

@author: lemmerfn
"""
import copy
import weakref
from abc import ABC, abstractmethod
from functools import total_ordering
from itertools import chain

import numpy as np

import pysubgroup as ps


@total_ordering
class SelectorBase(ABC):
    """Base class for selectors, ensuring each selector instance is unique.

    Selectors are interned: constructing a selector whose ``repr`` equals the
    ``repr`` of a selector that is still alive returns that existing object.
    Subclasses must implement ``set_descriptions`` (which has to set
    ``self._hash``) and ``__repr__``, because hashing and equality rely on them.
    """

    # selector cache; a WeakSet so cached selectors can still be collected
    __refs__ = weakref.WeakSet()

    def __new__(cls, *args, **kwargs):
        """Create a new selector, or return the cached equal one.

        Ensures that each selector only exists once by caching instances.
        """
        # create a temporary selector so it can be hashed and compared
        tmp = super().__new__(cls)
        tmp.set_descriptions(*args, **kwargs)
        # save the original arguments for `__getnewargs_ex__`, so pickle and
        # copy can call `__new__` with the right arguments when rebuilding
        # TODO: this may have unintended side effects if args,
        # kwargs are large or volatile (I don't think we have that yet though)
        tmp.__new_args__ = args, kwargs

        # check if selector is already in cache (__refs__);
        # if so, return the cached instance
        if tmp not in SelectorBase.__refs__:
            return tmp  # new selector
        # a WeakSet offers no lookup by key, so scan for the equal instance
        for ref in SelectorBase.__refs__:  # pragma no branch okay
            if ref == tmp:
                return ref
        return tmp  # pragma: no cover

    def __getnewargs_ex__(self):  # pylint: disable=invalid-getnewargs-ex-returned
        """Return ``(args, kwargs)`` needed to recreate the object when unpickling.

        The stored arguments are returned without being removed, so the same
        selector can be pickled or copied any number of times.
        """
        return self.__new_args__

    def __init__(self):
        """Initialize the SelectorBase and add it to the cache."""
        # TODO: why not do this in `__new__`,
        # then it would be all together in one function?
        SelectorBase.__refs__.add(self)

    def __eq__(self, other):
        """Check equality based on the string representation."""
        if other is None:  # pragma: no cover
            return False
        return repr(self) == repr(other)

    def __lt__(self, other):
        """Define less-than comparison based on the string representation."""
        return repr(self) < repr(other)

    def __hash__(self):
        """Return the hash value computed by ``set_descriptions``."""
        return self._hash  # pylint: disable=no-member

    @abstractmethod
    def set_descriptions(self, *args, **kwargs):
        """Set the descriptions for the selector (must set ``self._hash``)."""
def get_cover_array_and_size(subgroup, data_len=None, data=None):
    """Return the cover of ``subgroup`` together with its number of rows.

    Parameters:
        subgroup: One of
            - an object exposing ``representation`` (its ``size_sg`` is used
              as the size and the object itself as the cover),
            - a ``slice`` over the rows,
            - an array-like with ``__array_interface__`` holding either a
              boolean mask or integer row indices,
            - any description providing ``covers(data)``.
        data_len: Optional number of rows, used for slices.
        data: Optional pandas DataFrame; supplies the row count for slices
            and is passed to ``covers`` for descriptions.

    Returns:
        Tuple of (cover array, size).

    Raises:
        ValueError: A slice was given without ``data_len`` or a DataFrame.
        NotImplementedError: An array of an unsupported dtype kind was given.
    """
    if hasattr(subgroup, "representation"):
        return subgroup, subgroup.size_sg

    if isinstance(subgroup, slice):
        if data_len is None:
            if type(data).__name__ != "DataFrame":
                raise ValueError(
                    "if you pass a slice, you need to pass either data_len or data"
                )
            data_len = len(data)
        # count of positions the slice selects in range(data_len), see
        # https://stackoverflow.com/questions/36188429/retrieve-length-of-slice-from-slice-object-in-python
        return subgroup, len(range(*subgroup.indices(data_len)))

    if hasattr(subgroup, "__array_interface__"):
        interface = subgroup.__array_interface__
        type_char = interface["typestr"][1]
        if type_char == "b":
            # boolean mask: the size is the number of selected rows
            return subgroup, np.count_nonzero(subgroup)
        if type_char in ("u", "i"):
            # integer row indices: each entry selects one row
            return subgroup, interface["shape"][0]
        raise NotImplementedError(
            f"Currently a typechar of {type_char} is not supported."
        )

    assert type(data).__name__ == "DataFrame", str(type(data))
    covered = subgroup.covers(data)
    return covered, np.count_nonzero(covered)
def get_size(subgroup, data_len=None, data=None):
    """Compute the size of the cover array for a given subgroup.

    Accepts exactly the same kinds of ``subgroup`` as
    ``get_cover_array_and_size``: an object with ``representation``
    (``size_sg`` is used), a ``slice``, a boolean-mask or integer-index
    array, or a description providing ``covers(data)``.

    Parameters:
        subgroup: The subgroup for which to compute the size.
        data_len: Optional length of the data (used for slices).
        data: Optional pandas DataFrame (used for slices and ``covers``).

    Returns:
        Size of the cover array.

    Raises:
        ValueError: A slice was given without ``data_len`` or a DataFrame.
        NotImplementedError: An array of an unsupported dtype kind was given.
    """
    # Delegate so the dispatch over subgroup kinds lives in one place instead
    # of two diverging copies; the cover array computed alongside is dropped.
    _, size = get_cover_array_and_size(subgroup, data_len=data_len, data=data)
    return size
def pandas_sparse_eq(col, value):
    """Compare a pandas sparse column to a value.

    Parameters:
        col: pandas Series with SparseArray data.
        value: The value to compare with.

    Returns:
        A pandas SparseArray of booleans that is True where ``col == value``.
        Its fill value is the result of comparing the column's fill value with
        ``value``; only the stored positions whose result differs from that
        fill value are stored explicitly.
    """
    import pandas as pd  # pylint: disable=import-outside-toplevel
    from pandas._libs.sparse import (  # pylint: disable=import-outside-toplevel
        IntIndex,  # pylint: disable=no-name-in-module
    )

    col_arr = col.array
    # result for every row that holds the (implicit) fill value
    fill_matches = bool(col_arr.fill_value == value)
    is_same_value = np.asarray(col_arr.sp_values == value, dtype=bool)
    # Rows absent from the new index take the new fill value, so exactly the
    # stored entries whose comparison result differs from it must be listed.
    # Listing only the matching entries would mark every non-matching stored
    # entry as True whenever value == fill_value.
    deviates = is_same_value != fill_matches
    new_index_arr = col_arr.sp_index.indices[deviates]
    index = IntIndex(len(col), new_index_arr)
    return pd.arrays.SparseArray(
        np.full(len(new_index_arr), not fill_matches, dtype=bool),
        index,
        fill_matches,
        dtype=bool,
    )
class EqualitySelector(SelectorBase):
    """Selector that checks for equality with a specific value."""

    def __init__(self, attribute_name, attribute_value, selector_name=None):
        """Create a selector for ``attribute_name == attribute_value``.

        Parameters:
            attribute_name: Name of the column the selector tests.
            attribute_value: Value the column is compared with; a NaN value
                selects the missing entries of the column (see ``covers``).
            selector_name: Optional display string returned by ``__str__``.

        Raises:
            TypeError: If ``attribute_name`` or ``attribute_value`` is None.
        """
        if attribute_name is None:
            raise TypeError("attribute_name must not be None")
        if attribute_value is None:
            raise TypeError("attribute_value must not be None")
        # TODO: this is redundant due to `__new__` and `set_descriptions`
        self._attribute_name = attribute_name
        self._attribute_value = attribute_value
        self._selector_name = selector_name
        self.set_descriptions(
            self._attribute_name, self._attribute_value, self._selector_name
        )
        super().__init__()

    @property
    def attribute_name(self):
        """Name of the attribute."""
        return self._attribute_name

    @property
    def attribute_value(self):
        """Value of the attribute to compare for equality."""
        return self._attribute_value

    def set_descriptions(
        self, attribute_name, attribute_value, selector_name=None
    ):  # pylint: disable=arguments-differ
        """Set the descriptions (query, string, hash) for the selector."""
        self._hash, self._query, self._string = EqualitySelector.compute_descriptions(
            attribute_name, attribute_value, selector_name=selector_name
        )

    @staticmethod
    def _is_nan(value):
        """Return True if ``np.isnan`` reports ``value`` as NaN.

        Values ``np.isnan`` cannot handle (it raises ``TypeError``, e.g. for
        ``pd.NaT`` or arbitrary objects) are treated as not NaN.
        """
        try:
            return bool(np.isnan(value))
        except TypeError:
            return False

    @classmethod
    def compute_descriptions(cls, attribute_name, attribute_value, selector_name):
        """Compute the descriptions (hash, query, string) for the selector.

        Strings and bytes are quoted in the query, NaN becomes
        ``<name>.isnull()``, and any other value is rendered with ``str``.
        The display string is ``selector_name`` when given, else the query.
        """
        if isinstance(attribute_value, (str, bytes)):
            query = str(attribute_name) + "==" + "'" + str(attribute_value) + "'"
        elif attribute_value is None:
            query = str(attribute_name) + " is None"
        elif cls._is_nan(attribute_value):
            # str() so that non-string column names (e.g. ints) also work
            query = str(attribute_name) + ".isnull()"
        else:
            query = str(attribute_name) + "==" + str(attribute_value)
        string_ = selector_name if selector_name is not None else query
        hash_value = hash(query)
        return (hash_value, query, string_)

    def __repr__(self):
        """Representation of the selector as a query string."""
        return self._query

    def covers(self, data):
        """Determine which instances in data are covered by this selector.

        Parameters:
            data: pandas DataFrame containing the data.

        Returns:
            A boolean array indicating which instances are covered; missing
            values are matched via ``pd.isnull`` when the selector's value is
            null. For sparse columns a pandas SparseArray may be returned.
        """
        import pandas as pd  # pylint: disable=import-outside-toplevel

        column = data[self.attribute_name]
        if isinstance(column.dtype, pd.SparseDtype):
            row = column
            if not pd.isnull(self.attribute_value):
                return pandas_sparse_eq(column, self.attribute_value)
        else:
            row = column.to_numpy()
        if pd.isnull(self.attribute_value):
            return pd.isnull(row)
        return row == self.attribute_value

    def __str__(self, open_brackets="", closing_brackets=""):
        """String representation of the selector, optionally with brackets."""
        return open_brackets + self._string + closing_brackets

    @property
    def selectors(self):
        """Return the selector itself as a tuple (for compatibility)."""
        return (self,)

    @staticmethod
    def from_str(s):
        """Create an EqualitySelector from a string representation.

        Quoted values (``name=='abc'``) stay strings, and ``'b'...''`` values
        become bytes. Unquoted values are parsed as int, then float, then
        bool, and otherwise kept as the raw string.

        Parameters:
            s: String representation of the selector.

        Returns:
            An EqualitySelector instance.
        """
        s = s.strip()
        # split only once so values containing "==" are kept intact
        attribute_name, attribute_value = s.split("==", 1)
        if attribute_value[0] == "'" and attribute_value[-1] == "'":
            # quoted values were strings/bytes originally; do not re-interpret
            # e.g. "'1'" as the number 1
            if attribute_value.startswith("'b'") and attribute_value.endswith("''"):
                return EqualitySelector(
                    attribute_name, str.encode(attribute_value[3:-2])
                )
            return EqualitySelector(attribute_name, attribute_value[1:-1])
        try:
            attribute_value = int(attribute_value)
        except ValueError:
            try:
                attribute_value = float(attribute_value)
            except ValueError:
                try:
                    attribute_value = ps.str_to_bool(attribute_value)
                except ValueError:
                    pass
        return EqualitySelector(attribute_name, attribute_value)
class NegatedSelector(SelectorBase):
    """Selector that negates another selector."""

    def __init__(self, selector):
        """Create the negation of ``selector``."""
        # TODO: this is redundant due to `__new__` and `set_descriptions`
        self._selector = selector
        self.set_descriptions(selector)
        super().__init__()

    def covers(self, data_instance):
        """Determine which instances are not covered by the underlying selector.

        Parameters:
            data_instance: pandas DataFrame containing the data.

        Returns:
            A boolean array indicating which instances are not covered.
        """
        return np.logical_not(self._selector.covers(data_instance))

    def __repr__(self):
        """Representation of the negated selector as a query string."""
        return self._query

    def __str__(self, open_brackets="", closing_brackets=""):
        """String representation of the negated selector."""
        if open_brackets or closing_brackets:
            inner = self._selector.__str__(open_brackets, closing_brackets)
        else:
            # plain str() works for every selector, including those whose
            # __str__ accepts no bracket arguments (e.g. IntervalSelector)
            inner = str(self._selector)
        return "NOT " + inner

    def set_descriptions(self, selector):  # pylint: disable=arguments-differ
        """Set the descriptions (query, hash) for the negated selector."""
        self._query = "(not " + repr(selector) + ")"
        self._hash = hash(repr(self))

    @property
    def attribute_name(self):
        """Name of the attribute of the underlying selector."""
        return self._selector.attribute_name

    @property
    def selectors(self):
        """Return the selector itself as a tuple (for compatibility)."""
        return (self,)
# Including the lower bound, excluding the upper_bound
class IntervalSelector(SelectorBase):
    """Selector that checks if a value lies in ``[lower_bound, upper_bound)``."""

    def __init__(self, attribute_name, lower_bound, upper_bound, selector_name=None):
        """Create a selector for ``lower_bound <= attribute < upper_bound``.

        Parameters:
            attribute_name: Name of the column the selector tests.
            lower_bound: Inclusive lower bound (may be ``-inf``).
            upper_bound: Exclusive upper bound (may be ``inf``).
            selector_name: Optional display string returned by ``__str__``.
        """
        assert lower_bound < upper_bound
        # TODO: this is redundant due to `__new__` and `set_descriptions`
        self._attribute_name = attribute_name
        self._lower_bound = lower_bound
        self._upper_bound = upper_bound
        self.selector_name = selector_name
        self.set_descriptions(attribute_name, lower_bound, upper_bound, selector_name)
        super().__init__()

    @property
    def attribute_name(self):
        """Name of the attribute."""
        return self._attribute_name

    @property
    def lower_bound(self):
        """Lower bound of the interval (inclusive)."""
        return self._lower_bound

    @property
    def upper_bound(self):
        """Upper bound of the interval (exclusive)."""
        return self._upper_bound

    def covers(self, data_instance):
        """Determine which instances are covered by this interval selector.

        Parameters:
            data_instance: pandas DataFrame containing the data.

        Returns:
            A boolean array indicating which instances are within the interval.
        """
        val = data_instance[self.attribute_name].to_numpy()
        return np.logical_and((val >= self.lower_bound), (val < self.upper_bound))

    def __repr__(self):
        """Representation of the interval selector as a query string."""
        return self._query

    def __hash__(self):
        return self._hash

    def __str__(self, open_brackets="", closing_brackets=""):
        """String representation of the interval selector, optionally bracketed.

        Accepts the same bracket arguments as the other selectors so that
        wrappers such as NegatedSelector can forward them.
        """
        return open_brackets + self._string + closing_brackets

    @classmethod
    def compute_descriptions(
        cls, attribute_name, lower_bound, upper_bound, selector_name=None
    ):
        """Compute the descriptions (hash, query, string) for the interval selector.

        The display string rounds bounds to 2 digits unless ``selector_name``
        is given; the query keeps the bounds unrounded.
        """
        if selector_name is None:
            _string = cls.compute_string(
                attribute_name, lower_bound, upper_bound, rounding_digits=2
            )
        else:
            _string = selector_name
        _query = cls.compute_string(
            attribute_name, lower_bound, upper_bound, rounding_digits=None
        )
        _hash = hash(_query)
        return (_hash, _query, _string)

    def set_descriptions(
        self, attribute_name, lower_bound, upper_bound, selector_name=None
    ):  # pylint: disable=arguments-differ
        """Set the descriptions (hash, query, string) for the interval selector."""
        self._hash, self._query, self._string = IntervalSelector.compute_descriptions(
            attribute_name, lower_bound, upper_bound, selector_name=selector_name
        )

    @classmethod
    def compute_string(cls, attribute_name, lower_bound, upper_bound, rounding_digits):
        """Compute the string representation of the interval selector.

        Non-integral bounds are formatted with ``rounding_digits`` decimals
        (or ``str`` formatting when it is None). Infinite bounds produce the
        ``= anything``, ``<ub`` or ``>=lb`` forms; otherwise ``name: [lb:ub[``.
        """
        if rounding_digits is None:
            formatter = "{}"
        else:
            formatter = "{0:." + str(rounding_digits) + "f}"
        ub = upper_bound
        lb = lower_bound
        if ub % 1:
            ub = formatter.format(ub)
        if lb % 1:
            lb = formatter.format(lb)

        if lower_bound == float("-inf") and upper_bound == float("inf"):
            repre = str(attribute_name) + " = anything"
        elif lower_bound == float("-inf"):
            repre = str(attribute_name) + "<" + str(ub)
        elif upper_bound == float("inf"):
            repre = str(attribute_name) + ">=" + str(lb)
        else:
            repre = str(attribute_name) + ": [" + str(lb) + ":" + str(ub) + "["
        return repre

    @staticmethod
    def _parse_bound(text):
        """Parse one bound, preferring ``int`` and falling back to ``float``.

        Raises:
            ValueError: If ``text`` is neither an int nor a float literal.
        """
        text = text.strip()
        try:
            return int(text)
        except ValueError:
            return float(text)

    @staticmethod
    def from_str(s):
        """Create an IntervalSelector from a string representation.

        Each bound is parsed independently, so ``a: [1:2.5[`` keeps the
        integer lower bound 1 instead of turning it into 1.0.

        Parameters:
            s: String representation of the interval selector.

        Returns:
            An IntervalSelector instance.

        Raises:
            ValueError: If ``s`` matches none of the known formats.
        """
        s = s.strip()
        if s.endswith(" = anything"):
            return IntervalSelector(
                s[: -len(" = anything")], float("-inf"), float("+inf")
            )
        if "<" in s:
            attribute_name, ub = s.split("<")
            return IntervalSelector(
                attribute_name.strip(), float("-inf"), IntervalSelector._parse_bound(ub)
            )
        if ">=" in s:
            attribute_name, lb = s.split(">=")
            return IntervalSelector(
                attribute_name.strip(), IntervalSelector._parse_bound(lb), float("inf")
            )
        if s.count(":") == 2:
            attribute_name, lb, ub = s.split(":")
            lb = lb.strip()[1:]
            ub = ub.strip()[:-1]
            return IntervalSelector(
                attribute_name.strip(),
                IntervalSelector._parse_bound(lb),
                IntervalSelector._parse_bound(ub),
            )
        raise ValueError(f"string {s} could not be converted to IntervalSelector")

    @property
    def selectors(self):
        """Return the selector itself as a tuple (for compatibility)."""
        return (self,)
def create_selectors(data, nbins=5, intervals_only=True, ignore=None):
    """Build selectors for every attribute of ``data``.

    Nominal selectors come first, followed by the numeric ones.

    Parameters:
        data: pandas DataFrame containing the data.
        nbins: Number of bins used to discretize numeric attributes.
        intervals_only: If True, numeric attributes only get interval
            selectors covering consecutive bins.
        ignore: Attribute names to skip (defaults to none).

    Returns:
        List of selectors.
    """
    ignored = [] if ignore is None else ignore
    selectors = create_nominal_selectors(data, ignored)
    numeric = create_numeric_selectors(data, nbins, intervals_only, ignore=ignored)
    selectors.extend(numeric)
    return selectors
def create_nominal_selectors(data, ignore=None):
    """Create equality selectors for nominal attributes.

    Every non-numeric column (per ``select_dtypes(exclude=["number"])``) that
    is not in ``ignore`` contributes its selectors, in column order.

    Parameters:
        data: pandas DataFrame containing the data.
        ignore: List of attribute names to ignore.

    Returns:
        List of EqualitySelector instances.
    """
    if ignore is None:
        ignore = []
    nominal_columns = data.select_dtypes(exclude=["number"]).columns.values
    # dtypes are looked up once and shared with every per-attribute call
    dtypes = data.dtypes
    nominal_selectors = []
    for attr_name in nominal_columns:
        if attr_name in ignore:
            continue
        nominal_selectors.extend(
            create_nominal_selectors_for_attribute(data, attr_name, dtypes)
        )
    return nominal_selectors
def create_nominal_selectors_for_attribute(data, attribute_name, dtypes=None):
    """Create equality selectors for a nominal attribute.

    One selector is created per distinct value (in ``pd.unique`` order).
    Missing-value markers ``None`` and ``pd.NA`` are mapped to NaN, whose
    selector matches missing entries via ``pd.isnull``. Selectors that turn
    out identical are only listed once.

    Parameters:
        data: pandas DataFrame containing the data.
        attribute_name: Name of the attribute.
        dtypes: Data types of the data columns (defaults to ``data.dtypes``).

    Returns:
        List of EqualitySelector instances for the attribute. For bool
        columns every selector gets ``is_bool = True``.
    """
    import pandas as pd  # pylint: disable=import-outside-toplevel

    nominal_selectors = []
    seen = set()
    for val in pd.unique(data[attribute_name]):
        if val is None or val is pd.NA:
            # EqualitySelector rejects None and cannot describe pd.NA; a NaN
            # selector covers the column's missing entries instead
            val = np.nan
        selector = EqualitySelector(attribute_name, val)
        # several null markers can map onto the same selector
        if selector not in seen:
            seen.add(selector)
            nominal_selectors.append(selector)
    # setting the is_bool flag for selector
    if dtypes is None:
        dtypes = data.dtypes
    if dtypes[attribute_name] == "bool":
        for s in nominal_selectors:
            s.is_bool = True
    return nominal_selectors
def create_numeric_selectors(
    data, nbins=5, intervals_only=True, weighting_attribute=None, ignore=None
):
    """Build selectors for every numeric column of ``data``.

    Columns are taken from ``select_dtypes(include=["number"])`` in order,
    skipping those named in ``ignore``.

    Parameters:
        data: pandas DataFrame containing the data.
        nbins: Number of bins used for discretization.
        intervals_only: If True, only interval selectors for consecutive bins
            are created.
        weighting_attribute: Optional attribute used to weight the
            discretization.
        ignore: Attribute names to skip (defaults to none).

    Returns:
        List of numeric selectors.
    """
    if ignore is None:
        ignore = []  # pragma: no cover
    numeric_columns = data.select_dtypes(include=["number"]).columns.values
    selectors = []
    for attr_name in numeric_columns:
        if attr_name in ignore:
            continue
        selectors.extend(
            create_numeric_selectors_for_attribute(
                data, attr_name, nbins, intervals_only, weighting_attribute
            )
        )
    return selectors
def create_numeric_selectors_for_attribute(
    data, attr_name, nbins=5, intervals_only=True, weighting_attribute=None
):
    """Create selectors for a numeric attribute.

    If the column has at most ``nbins`` distinct non-null values, one
    equality selector per value is created. Otherwise the column is
    discretized with ``ps.equal_frequency_discretization``. A NaN selector is
    added when null values are present. For sparse columns a selector for the
    fill value is always added first.

    Parameters:
        data: pandas DataFrame containing the data.
        attr_name: Name of the attribute.
        nbins: Number of bins to use for discretization.
        intervals_only: If True, create intervals between consecutive
            cutpoints; otherwise create one ``>=`` and one ``<`` selector
            per cutpoint.
        weighting_attribute: Optional attribute for weighting.

    Returns:
        List of numeric selectors for the attribute.
    """
    import pandas as pd  # pylint: disable=import-outside-toplevel

    column = data[attr_name]
    numeric_selectors = []
    if isinstance(column.dtype, pd.SparseDtype):
        numeric_selectors.append(EqualitySelector(attr_name, column.sparse.fill_value))
        stored_values = column.sparse.sp_values
        data_not_null = stored_values[pd.notnull(stored_values)]
        unique_values = np.unique(data_not_null)
        if len(data_not_null) < len(stored_values):
            numeric_selectors.append(EqualitySelector(attr_name, np.nan))
    else:
        # Filter just this column: masking the whole DataFrame would copy
        # every other column as well only to read this one.
        data_not_null = column[column.notnull()]
        unique_values = np.unique(data_not_null)
        if len(data_not_null) < len(column):
            numeric_selectors.append(EqualitySelector(attr_name, np.nan))

    if len(unique_values) <= nbins:
        numeric_selectors.extend(
            EqualitySelector(attr_name, val) for val in unique_values
        )
    else:
        cutpoints = ps.equal_frequency_discretization(
            data, attr_name, nbins, weighting_attribute
        )
        if intervals_only:
            old_cutpoint = float("-inf")
            for c in cutpoints:
                numeric_selectors.append(IntervalSelector(attr_name, old_cutpoint, c))
                old_cutpoint = c
            numeric_selectors.append(
                IntervalSelector(attr_name, old_cutpoint, float("inf"))
            )
        else:
            for c in cutpoints:
                numeric_selectors.append(IntervalSelector(attr_name, c, float("inf")))
                numeric_selectors.append(IntervalSelector(attr_name, float("-inf"), c))

    return numeric_selectors
def remove_target_attributes(selectors, target):
    """Remove selectors that are based on target attributes.

    Parameters:
        selectors: List of selectors.
        target: The target object with get_attributes method.

    Returns:
        List of selectors whose ``attribute_name`` is not one of the target's
        attributes, in their original order.
    """
    # query the target once instead of once per selector
    target_attributes = target.get_attributes()
    return [sel for sel in selectors if sel.attribute_name not in target_attributes]
############## # Boolean expressions ##############
class BooleanExpressionBase(ABC):
    """Abstract base for boolean expressions over selectors.

    Subclasses implement the in-place ``append_and`` / ``append_or`` and a
    ``__copy__``; the ``&`` and ``|`` operators apply the in-place operation
    to a copy, leaving the left operand unchanged.
    """

    def _copy_then(self, operation, other):
        """Return a copy of ``self`` with the named in-place operation applied."""
        duplicate = copy.copy(self)
        getattr(duplicate, operation)(other)
        return duplicate

    def __or__(self, other):
        """Return a new expression combining ``self`` and ``other`` with OR."""
        return self._copy_then("append_or", other)

    def __and__(self, other):
        """Return a new expression combining ``self`` and ``other`` with AND."""
        return self._copy_then("append_and", other)

    @abstractmethod
    def append_and(self, to_append):
        """Append a selector or expression in place using logical AND."""

    @abstractmethod
    def append_or(self, to_append):
        """Append a selector or expression in place using logical OR."""

    @abstractmethod
    def __copy__(self):
        """Create a copy of the boolean expression."""
@total_ordering
class Conjunction(BooleanExpressionBase):
    """Conjunction of selectors (logical AND).

    ``repr`` and ``hash`` are computed lazily and cached; every method that
    changes the selector list invalidates that cache.
    """

    def __init__(self, selectors):
        """Create a conjunction from an iterable of selectors or a single one."""
        self._repr = None
        self._hash = None
        try:
            it = iter(selectors)
            self._selectors = list(it)
        except TypeError:
            # not iterable: treat the argument as a single selector
            self._selectors = [selectors]

    def covers(self, instance):
        """Determine which instances are covered by the conjunction.

        Parameters:
            instance: pandas DataFrame containing the data.

        Returns:
            A boolean array indicating which instances are covered.
        """
        # empty description ==> return a list of all '1's
        if not self._selectors:
            return np.full(len(instance), True, dtype=bool)
        # non-empty description
        return np.all([sel.covers(instance) for sel in self._selectors], axis=0)

    def __len__(self):
        """Return the number of selectors in the conjunction."""
        return len(self._selectors)

    def __str__(self, open_brackets="", closing_brackets="", and_term=" AND "):
        """String representation of the conjunction ("Dataset" when empty)."""
        if not self._selectors:
            return "Dataset"
        attrs = sorted(str(sel) for sel in self._selectors)
        return "".join((open_brackets, and_term.join(attrs), closing_brackets))

    def __repr__(self):
        """Representation of the conjunction (cached)."""
        if self._repr is None:
            self._repr = self._compute_repr()
        return self._repr

    def __eq__(self, other):
        """Check equality based on the string representation."""
        return repr(self) == repr(other)

    def __lt__(self, other):
        """Define less-than comparison based on the string representation."""
        return repr(self) < repr(other)

    def __hash__(self):
        """Return the hash value (cached)."""
        if self._hash is None:
            self._hash = self._compute_hash()
        return self._hash

    def _compute_repr(self):
        """Compute the representation of the conjunction ("True" when empty)."""
        if not self._selectors:
            return "True"
        reprs = sorted(repr(sel) for sel in self._selectors)
        return "(" + " and ".join(reprs) + ")"

    def _compute_hash(self):
        """Compute the hash of the conjunction."""
        return hash(repr(self))

    def _invalidate_representations(self):
        """Invalidate cached representations."""
        self._repr = None
        self._hash = None

    def append_and(self, to_append):
        """Append a selector, conjunction or iterable of selectors using AND."""
        if isinstance(to_append, SelectorBase):
            self._selectors.append(to_append)
        elif isinstance(to_append, Conjunction):
            self._selectors.extend(to_append.selectors)
        else:
            self._selectors.extend(to_append)
        self._invalidate_representations()

    def append_or(self, to_append):
        """Append a selector or expression using logical OR (not supported)."""
        raise RuntimeError(
            "Or operations are not supported by a pure Conjunction. Consider using DNF."
        )

    def pop_and(self):
        """Remove and return the last selector added using AND.

        Raises:
            IndexError: If the conjunction is empty.
        """
        popped = self._selectors.pop()
        # the cached repr/hash still describe the removed selector
        self._invalidate_representations()
        return popped

    def pop_or(self):
        """Pop operation for OR is not supported in Conjunction."""
        raise RuntimeError(
            "Or operations are not supported by a pure Conjunction. Consider using DNF."
        )

    def __copy__(self):
        """Create a copy of the conjunction with its own selector list."""
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        result._selectors = list(self._selectors)
        return result

    @property
    def depth(self):
        """Return the number of selectors in the conjunction."""
        return len(self._selectors)

    @property
    def selectors(self):
        """Return the selectors in the conjunction as a flat tuple."""
        return tuple(chain.from_iterable(sel.selectors for sel in self._selectors))

    @staticmethod
    def from_str(s):
        """Create a Conjunction from a string representation.

        Parameters:
            s: String representation of the conjunction ("Dataset" for the
                empty conjunction, otherwise selectors joined by " AND ").

        Returns:
            A Conjunction instance.
        """
        if s.strip() == "Dataset":
            return Conjunction([])
        selector_strings = s.split(" AND ")
        selectors = []
        for selector_string in selector_strings:
            selector_string = selector_string.strip()
            if "==" in selector_string:
                selectors.append(EqualitySelector.from_str(selector_string))
            else:
                selectors.append(IntervalSelector.from_str(selector_string))
        return Conjunction(selectors)
@total_ordering
class Disjunction(BooleanExpressionBase):
    """Disjunction of selectors (logical OR)."""

    def __init__(self, selectors=None):
        """Create a disjunction from a list/tuple of operands or a single one."""
        if isinstance(selectors, (list, tuple)):
            # Copy into a fresh list: a tuple has no `extend` (needed by
            # append_or), and keeping the caller's list would let later
            # appends mutate it.
            self._selectors = list(selectors)
        elif selectors is None:
            self._selectors = []
        else:
            self._selectors = [selectors]

    def covers(self, instance):
        """Determine which instances are covered by the disjunction.

        Parameters:
            instance: pandas DataFrame containing the data.

        Returns:
            A boolean array indicating which instances are covered.
        """
        # empty description ==> return a list of all '0's
        if not self._selectors:
            return np.full(len(instance), False, dtype=bool)
        # non-empty description
        return np.any([sel.covers(instance) for sel in self._selectors], axis=0)

    def __len__(self):
        """Return the number of operands in the disjunction."""
        return len(self._selectors)

    def __str__(self, open_brackets="", closing_brackets="", or_term=" OR "):
        """String representation of the disjunction."""
        if not self._selectors:
            return "Empty"  # pragma: no cover
        attrs = sorted(str(sel) for sel in self._selectors)
        return "".join((open_brackets, or_term.join(attrs), closing_brackets))

    def __repr__(self):
        """Representation of the disjunction."""
        if not self._selectors:
            # NOTE(review): an empty disjunction covers nothing (see covers),
            # yet its repr is "True"; kept for compatibility — confirm intent.
            return "True"
        reprs = sorted(repr(sel) for sel in self._selectors)
        return "".join(("(", " or ".join(reprs), ")"))

    def __eq__(self, other):
        """Check equality based on the string representation."""
        return repr(self) == repr(other)

    def __lt__(self, other):
        """Define less-than comparison based on the string representation."""
        return repr(self) < repr(other)

    def __hash__(self):
        """Return the hash value."""
        return hash(repr(self))

    def append_and(self, to_append):
        """Append a selector or expression using logical AND (not supported)."""
        raise RuntimeError(
            "And operations are not supported by a pure Disjunction. "
            "Consider using DNF."
        )

    def append_or(self, to_append):
        """Append an operand, an iterable of operands, or a disjunction using OR."""
        if isinstance(to_append, Disjunction):
            # take the direct operands; the `selectors` property would flatten
            # nested conjunctions into their individual selectors
            self._selectors.extend(to_append._selectors)
            return
        try:
            self._selectors.extend(to_append)
        except TypeError:
            self._selectors.append(to_append)

    def __copy__(self):
        """Create a copy of the disjunction.

        Nested boolean expressions (e.g. the conjunctions of a DNF) are copied
        as well. DNF.append_and mutates its conjunctions in place, so sharing
        them would make ``a & b`` modify ``a``.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        result._selectors = [
            copy.copy(sel) if isinstance(sel, BooleanExpressionBase) else sel
            for sel in self._selectors
        ]
        return result

    @property
    def selectors(self):
        """Return all selectors of the disjunction as a flat tuple."""
        return tuple(chain.from_iterable(sel.selectors for sel in self._selectors))
class DNF(Disjunction):
    """Disjunctive Normal Form expression: an OR of Conjunctions."""

    def __init__(self, selectors=None):
        """Create a DNF from selectors, conjunctions or iterables of them."""
        if selectors is None:
            selectors = []
        super().__init__([])
        self.append_or(selectors)

    @staticmethod
    def _ensure_pure_conjunction(to_append):
        """Ensure that the appended expression is a pure conjunction.

        Conjunctions are returned unchanged, a single selector is wrapped in a
        Conjunction, and an iterable of selectors becomes one Conjunction.

        Raises:
            TypeError: If ``to_append`` is neither a selector nor iterable.
            ValueError: If the iterable contains non-selector items.
        """
        if isinstance(to_append, Conjunction):
            return to_append
        if isinstance(to_append, SelectorBase):
            return Conjunction(to_append)
        # materialise once: checking the items and then building the
        # Conjunction must not consume a one-shot iterator twice
        parts = list(to_append)
        if all(isinstance(sel, SelectorBase) for sel in parts):
            return Conjunction(parts)
        raise ValueError(
            "DNFs only accept an iterable of Selectors"
        )  # pragma: no cover

    def append_or(self, to_append):
        """Append a selector, conjunction, disjunction or iterable of them using OR."""
        if isinstance(to_append, Disjunction):
            # Use the direct operands: the `selectors` property would flatten
            # each conjunction into separate single-selector terms and change
            # the expression. Conjunctions are copied so later in-place ANDs
            # on this DNF do not alter `to_append`.
            to_append = [
                copy.copy(part) if isinstance(part, Conjunction) else part
                for part in to_append._selectors
            ]
        try:
            it = iter(to_append)
            conjunctions = [DNF._ensure_pure_conjunction(part) for part in it]
        except TypeError:
            conjunctions = DNF._ensure_pure_conjunction(to_append)
        super().append_or(conjunctions)

    def append_and(self, to_append):
        """Append a selector using logical AND to all conjunctions.

        On an empty DNF the appended conjunction becomes its only term.

        Raises:
            NotImplementedError: If ``to_append`` is a disjunction.
        """
        if isinstance(to_append, Disjunction):
            raise NotImplementedError(
                "Appending a disjunction to DNF is not implemented"
            )
        conj = DNF._ensure_pure_conjunction(to_append)
        if self._selectors:
            for conjunction in self._selectors:
                conjunction.append_and(conj)
        else:
            self._selectors.append(conj)

    def pop_and(self):
        """Remove and return the last selector added using AND from all conjunctions.

        Raises:
            IndexError: If the DNF has no conjunctions.
            RuntimeError: If the conjunctions did not all end with the same
                selector; the popped selectors are re-appended before raising.
        """
        out_list = [s.pop_and() for s in self._selectors]
        return_val = out_list[0]
        if all(x == return_val for x in out_list):
            return return_val
        # restore every conjunction before reporting the inconsistency
        for to_append, conj in zip(out_list, self._selectors):
            conj.append_and(to_append)
        raise RuntimeError("pop_and failed as the result was inconsistent")