diff --git a/CHANGELOG.md b/CHANGELOG.md index 7464e9a8f..f89a4a6cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,9 @@ ### Added ### Fixed ### Changed +- Move magic methods (`__radd__`, `__sub__`, `__rsub__`, `__rmul__`, `__richcmp__`, `__neg__`, and `__rtruediv__`) to `ExprLike` base class - Speed up `Expr.__add__` and `Expr.__iadd__` via the C-level API +- Speed up `SumExpr.__neg__`, `ProdExpr.__neg__` and `Constant.__neg__` via C-level API ### Removed ## 6.2.1 - 2026.05.16 diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index bb25fdb29..b1371e880 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -257,6 +257,27 @@ cdef class ExprLike: return NotImplemented + def __radd__(self, other, /): + return self + other + + def __sub__(self, other, /): + return self + (-other) + + def __rsub__(self, other, /): + return (-self) + other + + def __rmul__(self, other, /): + return self * other + + def __rtruediv__(self, other, /) -> GenExpr: + return buildGenExprObj(other) / self + + def __richcmp__(self, other, int op): + return _expr_richcmp(self, other, op) + + def __neg__(self, /) -> Union[Expr, GenExpr]: + return self * -1.0 + def __abs__(self) -> GenExpr: return UnaryExpr(Operator.fabs, buildGenExprObj(self)) @@ -358,11 +379,10 @@ cdef class Expr(ExprLike): return 1.0 / other * self return buildGenExprObj(self) / other - def __rtruediv__(self, other): - ''' other / self ''' + def __rtruediv__(self, other, /) -> GenExpr: if not _is_expr_compatible(other): return NotImplemented - return buildGenExprObj(other) / self + return super().__rtruediv__(other) def __pow__(self, other, modulo): if float(other).is_integer() and other >= 0: @@ -387,25 +407,6 @@ cdef class Expr(ExprLike): raise ValueError("Base of a**x must be positive, as expression is reformulated to scip.exp(x * scip.log(a)); got %g" % base) return (self * Constant(base).log()).exp() - def __neg__(self): - return Expr({v:-c for v,c in self.terms.items()}) - - def __sub__(self, other): - return self + (-other) - - def __radd__(self, other): - return self.__add__(other) - - def __rmul__(self, other): - return self.__mul__(other) - - def __rsub__(self, other): - return -1.0 * self + other - - def __richcmp__(self, other, int op): - '''turn it into a constraint''' - return _expr_richcmp(self, other, op) - def normalize(self): '''remove terms with coefficient of 0''' self.terms = {t:c for (t,c) in self.terms.items() if c != 0.0} @@ -464,7 +465,6 @@ cdef class ExprCons: if not self._rhs is None: self._rhs -= c - def __richcmp__(self, other, op): '''turn it into a constraint''' if not _is_number(other): @@ -690,30 +690,10 @@ cdef class GenExpr(ExprLike): raise ZeroDivisionError("cannot divide by 0") return self * divisor**(-1) - def __rtruediv__(self, other): - ''' other / self ''' + def __rtruediv__(self, other, /) -> GenExpr: if not _is_genexpr_compatible(other): return NotImplemented - return buildGenExprObj(other) / self - - def __neg__(self): - return -1.0 * self - - def __sub__(self, other): - return self + (-other) - - def __radd__(self, other): - return self.__add__(other) - - def __rmul__(self, other): - return self.__mul__(other) - - def __rsub__(self, other): - return -1.0 * self + other - - def __richcmp__(self, other, int op): - '''turn it into a constraint''' - return _expr_richcmp(self, other, op) + return super().__rtruediv__(other) def degree(self): '''Note: none of these expressions should be polynomial''' @@ -749,6 +729,23 @@ cdef class SumExpr(GenExpr): self.coefs = [] self.children = [] self._op = Operator.add + + def __neg__(self) -> SumExpr: + cdef int i = 0, n = len(self.coefs) + cdef list coefs = [0.0] * n + cdef double[:] dest_view = coefs + cdef double[:] src_view = self.coefs + + for i in range(n): + dest_view[i] = -src_view[i] + + cdef SumExpr res = SumExpr.__new__(SumExpr) + res.coefs = coefs + res.children = self.children.copy() + res.constant = -self.constant + res._op = Operator.add + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -756,7 +753,7 @@ cdef class SumExpr(GenExpr): cdef double res = self.constant cdef int i = 0, n = len(self.children) cdef list children = self.children - cdef list coefs = self.coefs + cdef double[:] coefs = self.coefs for i in range(n): res += coefs[i] * (children[i])._evaluate(sol) return res @@ -772,6 +769,11 @@ cdef class ProdExpr(GenExpr): self.children = [] self._op = Operator.prod + def __neg__(self) -> ProdExpr: + cdef ProdExpr res = self.copy(copy=True) + res.constant = -res.constant + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -841,11 +843,16 @@ cdef class UnaryExpr(GenExpr): # class for constant expressions cdef class Constant(GenExpr): + cdef public number + def __init__(self,number): self.number = number self._op = Operator.const + def __neg__(self) -> Constant: + return Constant(-self.number) + def __repr__(self): return str(self.number) diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index ee82c00d9..d22dcffc6 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -331,6 +331,12 @@ class ExprLike: *args: Incomplete, **kwargs: Incomplete, ) -> Incomplete: ... + def __radd__(self, other: object, /) -> Incomplete: ... + def __sub__(self, other: object, /) -> Incomplete: ... + def __rsub__(self, other: object, /) -> Incomplete: ... + def __rmul__(self, other: object, /) -> Incomplete: ... + def __rtruediv__(self, other: object, /) -> GenExpr: ... + def __neg__(self, /) -> Union[Expr, GenExpr]: ... def __abs__(self) -> GenExpr: ... def exp(self) -> GenExpr: ... def log(self) -> GenExpr: ... @@ -344,7 +350,6 @@ class Expr(ExprLike): def __init__(self, terms: Incomplete = ...) -> None: ... def degree(self) -> Incomplete: ... def normalize(self) -> Incomplete: ... - def __abs__(self) -> GenExpr: ... def __add__(self, other: Incomplete, /) -> Incomplete: ... def __eq__(self, other: object, /) -> bool: ... def __ge__(self, other: object, /) -> bool: ... @@ -356,14 +361,8 @@ class Expr(ExprLike): def __lt__(self, other: object, /) -> bool: ... def __mul__(self, other: Incomplete, /) -> Incomplete: ... def __ne__(self, other: object, /) -> bool: ... - def __neg__(self) -> Incomplete: ... def __pow__(self, other: Incomplete, modulo: Incomplete = ..., /) -> Incomplete: ... - def __radd__(self, other: Incomplete, /) -> Incomplete: ... - def __rmul__(self, other: Incomplete, /) -> Incomplete: ... def __rpow__(self, other: Incomplete, /) -> Incomplete: ... - def __rsub__(self, other: Incomplete, /) -> Incomplete: ... - def __rtruediv__(self, other: Incomplete, /) -> Incomplete: ... - def __sub__(self, other: Incomplete, /) -> Incomplete: ... def __truediv__(self, other: Incomplete, /) -> Incomplete: ... @disjoint_base @@ -391,23 +390,23 @@ class GenExpr(ExprLike): def degree(self) -> Incomplete: ... def getOp(self) -> Incomplete: ... def __abs__(self) -> GenExpr: ... - def __add__(self, other: Incomplete, /) -> Incomplete: ... - def __eq__(self, other: object, /) -> bool: ... - def __ge__(self, other: object, /) -> bool: ... - def __gt__(self, other: object, /) -> bool: ... - def __le__(self, other: object, /) -> bool: ... - def __lt__(self, other: object, /) -> bool: ... - def __mul__(self, other: Incomplete, /) -> Incomplete: ... - def __ne__(self, other: object, /) -> bool: ... + def __add__(self, other: Incomplete) -> Incomplete: ... + def __eq__(self, other: object) -> bool: ... + def __ge__(self, other: object) -> bool: ... + def __gt__(self, other: object) -> bool: ... + def __le__(self, other: object) -> bool: ... + def __lt__(self, other: object) -> bool: ... + def __mul__(self, other: Incomplete) -> Incomplete: ... + def __ne__(self, other: object) -> bool: ... def __neg__(self) -> Incomplete: ... - def __pow__(self, other: Incomplete, modulo: Incomplete = ..., /) -> Incomplete: ... - def __radd__(self, other: Incomplete, /) -> Incomplete: ... - def __rmul__(self, other: Incomplete, /) -> Incomplete: ... - def __rpow__(self, other: Incomplete, /) -> Incomplete: ... - def __rsub__(self, other: Incomplete, /) -> Incomplete: ... - def __rtruediv__(self, other: Incomplete, /) -> Incomplete: ... - def __sub__(self, other: Incomplete, /) -> Incomplete: ... - def __truediv__(self, other: Incomplete, /) -> Incomplete: ... + def __pow__(self, other: Incomplete, modulo: Incomplete = ...) -> Incomplete: ... + def __radd__(self, other: Incomplete) -> Incomplete: ... + def __rmul__(self, other: Incomplete) -> Incomplete: ... + def __rpow__(self, other: Incomplete) -> Incomplete: ... + def __rsub__(self, other: Incomplete) -> Incomplete: ... + def __rtruediv__(self, other: Incomplete) -> Incomplete: ... + def __sub__(self, other: Incomplete) -> Incomplete: ... + def __truediv__(self, other: Incomplete) -> Incomplete: ... @disjoint_base class Heur: diff --git a/tests/test_expr.py b/tests/test_expr.py index f35096f73..9650c725b 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -1,10 +1,17 @@ -import math - import numpy as np import pytest from pyscipopt import Model, cos, exp, log, quickprod, sin, sqrt -from pyscipopt.scip import CONST, Expr, ExprCons, GenExpr, MatrixGenExpr +from pyscipopt.scip import ( + CONST, + Constant, + Expr, + ExprCons, + GenExpr, + MatrixGenExpr, + ProdExpr, + SumExpr, +) @pytest.fixture(scope="module") @@ -222,6 +229,36 @@ def test_getVal_with_GenExpr(): m.getVal(1 / z) +def test_neg(): + m = Model() + x = m.addVar(name="x") + + expr = (x + 1) ** 3 + neg_expr = -expr + assert isinstance(expr, Expr) + assert isinstance(neg_expr, Expr) + assert ( + str(neg_expr) + == "Expr({Term(x, x, x): -1.0, Term(x, x): -3.0, Term(x): -3.0, Term(): -1.0})" + ) + + base = sqrt(x) + expr = base * -1 + neg_expr = -expr + assert isinstance(expr, ProdExpr) + assert isinstance(neg_expr, ProdExpr) + assert str(neg_expr) == "prod(1.0,sqrt(sum(0.0,prod(1.0,x))))" + + expr = base + x - 1 + neg_expr = -expr + assert isinstance(expr, SumExpr) + assert isinstance(neg_expr, SumExpr) + assert str(neg_expr) == "sum(1.0,sqrt(sum(0.0,prod(1.0,x))),prod(1.0,x))" + assert list(neg_expr.coefs) == [-1, -1] + + assert str(-Constant(3.0)) == "-3.0" + + def test_unary_ufunc(model): m, x, y, z = model