import numpy as np
from sympy.core import S
from sympy.core.numbers import Float, Integer, Rational
from sympy.simplify.simplify import simplify as _simplify
# The following functions are taken from sympy.matrices.utilities
# and sympy.matrices.determinant and sympy.matrices.reductions.
# This bit of sympy appears to be in a high state of flux,
# so we copy the functions here for the time being.
def _iszero(x):
"""Returns True if x is zero."""
return getattr(x, "is_zero", None)
def _find_reasonable_pivot(col, iszerofunc=_iszero, simpfunc=_simplify):
    """
    Find the lowest index of an item in ``col`` that is
    suitable for a pivot.

    If ``col`` consists only of Floats and Integers (with at least
    one Float), it is treated as numerical: the entry with the
    largest absolute value is returned (partial pivoting).
    Otherwise, the first element for which ``iszerofunc`` returns
    False is used. If ``iszerofunc`` doesn't return False for any
    element, undecided items are simplified with ``simpfunc`` and
    retested, and finally checked with ``.equals(0)``.

    ``iszerofunc`` should return something that compares equal to
    True or False, or None when it cannot decide. SymPy's
    ``true``/``false`` and NumPy booleans are accepted as well as
    Python bools.

    Returns a 4-tuple
    (pivot_offset, pivot_val, assumed_nonzero, newly_determined)
    where pivot_offset is the index of the pivot (None if there is
    no pivot), pivot_val is the (possibly simplified) value of the
    pivot, assumed_nonzero is True if an assumption that the pivot
    was non-zero was made without being proved, and
    newly_determined is a list of (index, value) pairs for elements
    that were simplified during the process of pivot finding.
    """
    newly_determined = []
    col = list(col)
    # a column that contains a mix of floats and integers
    # but at least one float is considered a numerical
    # column, and so we do partial pivoting
    if all(isinstance(x, (Float, Integer)) for x in col) and any(
        isinstance(x, Float) for x in col
    ):
        col_abs = [abs(x) for x in col]
        max_value = max(col_abs)
        if iszerofunc(max_value):
            # just because iszerofunc returned True, doesn't
            # mean the value is numerically zero. Make sure
            # to replace all entries with numerical zeros
            if max_value != 0:
                newly_determined = [(i, 0) for i, x in enumerate(col) if x != 0]
            return (None, None, False, newly_determined)
        index = col_abs.index(max_value)
        return (index, col[index], False, newly_determined)
    # PASS 1 (iszerofunc directly)
    possible_zeros = []
    for i, x in enumerate(col):
        is_zero = iszerofunc(x)
        # if someone wrote a custom iszerofunc, it may return
        # BooleanFalse/BooleanTrue (or NumPy booleans) instead of
        # True or False, so use == for comparison instead of `is`.
        # With `is`, such a value would never count as "non-zero",
        # would be skipped by the later passes (it is not None), and
        # the final ``possible_zeros.index(None)`` would raise.
        if is_zero == False:  # noqa: E712 -- deliberate, see above
            # we found something that is definitely not zero
            return (i, x, False, newly_determined)
        possible_zeros.append(is_zero)
    # by this point, we've found no certain non-zeros
    if all(possible_zeros):
        # if everything is definitely zero, we have
        # no pivot
        return (None, None, False, newly_determined)
    # PASS 2 (iszerofunc after simplify)
    # we haven't found any for-sure non-zeros, so
    # go through the elements iszerofunc couldn't
    # make a determination about and opportunistically
    # simplify to see if we find something
    for i, x in enumerate(col):
        if possible_zeros[i] is not None:
            continue
        simped = simpfunc(x)
        is_zero = iszerofunc(simped)
        # ``in`` compares with ==, so Boolean-like values are accepted
        if is_zero in (True, False):
            newly_determined.append((i, simped))
        if is_zero == False:  # noqa: E712 -- deliberate, see PASS 1
            return (i, simped, False, newly_determined)
        possible_zeros[i] = is_zero
    # after simplifying, some things that were recognized
    # as zeros might be zeros
    if all(possible_zeros):
        # if everything is definitely zero, we have
        # no pivot
        return (None, None, False, newly_determined)
    # PASS 3 (.equals(0))
    # some expressions fail to simplify to zero, but
    # ``.equals(0)`` evaluates to True. As a last-ditch
    # attempt, apply ``.equals`` to these expressions
    for i, x in enumerate(col):
        if possible_zeros[i] is not None:
            continue
        if x.equals(S.Zero):
            # ``.iszero`` may return False with
            # an implicit assumption (e.g., ``x.equals(0)``
            # when ``x`` is a symbol), so only treat it
            # as proved when ``.equals(0)`` returns True
            possible_zeros[i] = True
            newly_determined.append((i, S.Zero))
    if all(possible_zeros):
        return (None, None, False, newly_determined)
    # at this point there is nothing that could definitely
    # be a pivot. To maintain compatibility with existing
    # behavior, we'll assume that an ill-determined thing is
    # non-zero. We should probably raise a warning in this case
    i = possible_zeros.index(None)
    return (i, col[i], True, newly_determined)
def _row_reduce_list(
    mat,
    rows,
    cols,
    iszerofunc,
    simpfunc,
    normalize_last=True,
    normalize=True,
    zero_above=True,
):
    """
    Row reduce a flat list representation of a matrix and return a tuple
    (rref_matrix, pivot_cols, swaps) where ``rref_matrix`` is a flat list,
    ``pivot_cols`` are the pivot columns and ``swaps`` are any row swaps that
    were used in the process of row reduction.

    The list is stored row-major (entry (i, j) is ``mat[i * cols + j]``) and
    is modified in place; the returned ``rref_matrix`` is the same list
    object as ``mat``. ``swaps`` is a tuple of ``(row_a, row_b)`` pairs in
    the order the swaps were applied.

    :param mat: list of matrix elements, must be ``rows`` * ``cols`` in length.
    :type mat: list
    :param rows: number of rows in flat list representation.
    :type rows: integer
    :param cols: number of columns in flat list representation.
    :type cols: integer
    :param iszerofunc: determines if an entry can be used as a pivot.
    :type iszerofunc: function
    :param simpfunc: used to simplify elements and test if they are
        zero if ``iszerofunc`` returns `None`.
    :type simpfunc: function
    :param normalize_last: indicates whether all row reduction should
        happen in a fraction-free manner and then the rows are
        normalized (so that the pivots are 1), or whether
        rows should be normalized along the way (like the naive
        row reduction algorithm).
    :type normalize_last: bool
    :param normalize: whether pivot rows should be normalized so that
        the pivot value is 1. NOTE(review): this flag is only consulted
        when ``normalize_last`` is True; with ``normalize_last=False``
        rows are normalized during elimination regardless.
    :type normalize: bool
    :param zero_above: whether entries above the pivot should be zeroed.
        If ``zero_above=False``, an echelon matrix will be returned.
    :type zero_above: bool
    """
    # column ``i`` of the row-major flat list (stride ``cols``)
    def get_col(i):
        return mat[i::cols]
    # exchange rows i and j in place via slice assignment
    def row_swap(i, j):
        mat[i * cols : (i + 1) * cols], mat[j * cols : (j + 1) * cols] = (
            mat[j * cols : (j + 1) * cols],
            mat[i * cols : (i + 1) * cols],
        )
    def cross_cancel(a, i, b, j):
        """Does the row op row[i] = a*row[i] - b*row[j]"""
        # offset from an element of row i to the same column in row j
        q = (j - i) * cols
        for p in range(i * cols, (i + 1) * cols):
            mat[p] = isimp(a * mat[p] - b * mat[p + q])
    # intermediate-simplification hook; currently the identity.
    # (Defined after cross_cancel but looked up at call time, so this is fine.)
    def isimp(x):
        return x
    piv_row, piv_col = 0, 0
    pivot_cols = []
    swaps = []
    # use a fraction free method to zero above and below each pivot
    while piv_col < cols and piv_row < rows:
        # search only the not-yet-used rows of the current column;
        # ``assumed_nonzero`` is returned but not used here
        (
            pivot_offset,
            pivot_val,
            assumed_nonzero,
            newly_determined,
        ) = _find_reasonable_pivot(get_col(piv_col)[piv_row:], iszerofunc, simpfunc)
        # _find_reasonable_pivot may have simplified some things
        # in the process. Let's not let them go to waste
        for offset, val in newly_determined:
            # offsets are relative to piv_row (the column slice above)
            offset += piv_row
            mat[offset * cols + piv_col] = val
        if pivot_offset is None:
            # no pivot in this column: move right, keep the same row
            piv_col += 1
            continue
        pivot_cols.append(piv_col)
        if pivot_offset != 0:
            # bring the pivot row up to piv_row and record the swap
            row_swap(piv_row, pivot_offset + piv_row)
            swaps.append((piv_row, pivot_offset + piv_row))
        # if we aren't normalizing last, we normalize
        # before we zero the other rows
        if normalize_last is False:
            i, j = piv_row, piv_col
            mat[i * cols + j] = S.One
            # only entries right of the pivot are divided
            for p in range(i * cols + j + 1, (i + 1) * cols):
                mat[p] = isimp(mat[p] / pivot_val)
            # after normalizing, the pivot value is 1
            pivot_val = S.One
        # zero above and below the pivot
        for row in range(rows):
            # don't zero our current row
            if row == piv_row:
                continue
            # don't zero above the pivot unless we're told.
            if zero_above is False and row < piv_row:
                continue
            # if we're already a zero, don't do anything
            val = mat[row * cols + piv_col]
            if iszerofunc(val):
                continue
            # row = pivot_val*row - val*pivot_row, so the entry in
            # piv_col becomes pivot_val*val - val*pivot_val = 0
            # without any division (fraction free)
            cross_cancel(pivot_val, row, val, piv_row)
        piv_row += 1
    # normalize each row
    if normalize_last is True and normalize is True:
        # pivot row k is row k of the result (pivots were moved up in order)
        for piv_i, piv_j in enumerate(pivot_cols):
            pivot_val = mat[piv_i * cols + piv_j]
            mat[piv_i * cols + piv_j] = S.One
            for p in range(piv_i * cols + piv_j + 1, (piv_i + 1) * cols):
                mat[p] = isimp(mat[p] / pivot_val)
    return mat, tuple(pivot_cols), tuple(swaps)
# This function is a candidate for caching
# if caching gets implemented for matrices.
def row_reduce(
    M,
    iszerofunc=lambda x: x.is_zero,
    simpfunc=lambda x: Rational(x).limit_denominator(1000),
    normalize_last=True,
    normalize=True,
    zero_above=True,
):
    """
    Row reduce the matrix ``M``.

    ``M`` itself is not modified. Its entries are copied into a flat
    row-major list, reduced, and wrapped into a new matrix with
    ``M._new``.

    :param M: matrix providing ``rows``, ``cols``, row-major iteration
        and ``_new`` (e.g. a SymPy ``Matrix``).
    :param iszerofunc: decides whether an entry is zero. It returns True,
        False, or None when undecided. The default reads ``x.is_zero``.
    :param simpfunc: applied to entries that ``iszerofunc`` could not
        decide on. The default converts the entry to a ``Rational``
        whose denominator is at most 1000.
    :param normalize_last: reduce fraction-free and normalize pivots at
        the end (True), or normalize each pivot row as it is found (False).
    :param normalize: whether pivot rows are scaled so the pivot is 1
        (consulted only when ``normalize_last`` is True).
    :param zero_above: whether entries above each pivot are eliminated;
        if False an echelon form is produced instead of the RREF.
    :returns: ``(reduced_matrix, pivot_cols, swaps)``. ``pivot_cols`` is a
        tuple of pivot column indices. ``swaps`` is a tuple of
        ``(row_a, row_b)`` row exchanges, listed in the order they were
        applied.
    """
    mat, pivot_cols, swaps = _row_reduce_list(
        list(M),
        M.rows,
        M.cols,
        iszerofunc,
        simpfunc,
        normalize_last=normalize_last,
        normalize=normalize,
        zero_above=zero_above,
    )
    return M._new(M.rows, M.cols, mat), pivot_cols, swaps
def independent_row_indices(
    m,
    iszerofunc=lambda x: x.is_zero,
    simpfunc=lambda x: Rational(x).limit_denominator(1000),
):
    """
    Return the indices of the rows of ``m`` that supply the pivots when
    ``m`` is row reduced, i.e. a linearly independent subset of rows.

    The row swaps performed by :func:`row_reduce` are replayed on the
    list of row indices. This maps the pivot rows of the reduced matrix
    (its first ``len(pivots)`` rows) back to rows of the original matrix.

    :param m: matrix accepted by :func:`row_reduce` (e.g. a SymPy ``Matrix``).
    :param iszerofunc: zero test forwarded to :func:`row_reduce`.
    :param simpfunc: simplification function forwarded to :func:`row_reduce`.
    :returns: NumPy integer array of original row indices, one per pivot.
    """
    _, pivots, swaps = row_reduce(m, iszerofunc, simpfunc)
    # indices[k] is the original row now sitting at position k.
    # Size by the row count: for a SymPy matrix ``len(m)`` is the number
    # of entries (rows * cols), not the number of rows.
    indices = np.arange(m.rows)
    for i, j in swaps:
        # the right-hand side is a copy, so this is a proper exchange
        indices[[i, j]] = indices[[j, i]]
    # pivot rows are moved to the top in order during the reduction
    return indices[: len(pivots)]