Parsing and evaluating equations with inequalities using sympy in Python results in a SyntaxError

Posted by d7v8vwbk on 2023-06-25 in Python
import matplotlib.pyplot as plt
import numpy as np
from sympy import symbols, Eq, solve, lambdify, parse_expr
from sympy.core.relational import Relational

# Define symbolic variables
x1, x2 = symbols('x1 x2')

# List of equations or inequalities that serve as constraints in the two-dimensional linear system
sel = ["2 * X1 + 1 * X2 >= 2", "3 * X1 + 4 * X2 <= 12", "4 * X1 + 3 * X2 <= 12"]
sel = ["2 * X1 + 1 * X2 >= 2", "3 * X1 + 4 * X2 == 12", "4 * X1 + 3 * X2 <= 12"]
sel = ["2 * X1 + 1 * X2 >= 2", "3 * X1 + 4 * X2 <= 12", "4 * X1 + 3 * X2 <= 12", "X1 + X2 - 4 == 0"]

# Linear constraints (inequalities)
linear_constraints = []
# Linear constraints (equations)
equations = []

# Separate the inequalities from the equations
for equation in sel:
    parsed_eq = equation.replace('X1', 'x1').replace('X2', 'x2')
    if '>=' in parsed_eq or '<=' in parsed_eq:
        linear_constraints.append(Relational(parse_expr(parsed_eq), 0))
    else:
        equations.append(parse_expr(parsed_eq))

# Define the range of values for the x-axis
x = np.linspace(-10, 10, 100)

# List to store the intersection points
intersections = []

# Solve the inequalities and find the intersections
for constraint in linear_constraints:
    if isinstance(constraint, Relational):
        inequality = constraint.rel_op

        if inequality == '>=':
            r = solve(Eq(constraint.lhs - constraint.rhs, 0), x2)[0]
        elif inequality == '<=':
            r = solve(Eq(constraint.lhs - constraint.rhs, 0), x2)[0]
        else:
            raise ValueError("Invalid inequality sign")

        intersections.append((0, r.subs(x1, 0)))
        intersections.append((solve(Eq(constraint.lhs - constraint.rhs, 0), x1)[0].subs(x2, 0), 0))
    else:
        r = solve(Eq(constraint, 0), x2)[0]
        intersections.append((0, r.subs(x1, 0)))
        intersections.append((solve(Eq(constraint, 0), x1)[0].subs(x2, 0), 0))

# Solve the equations and find the intersections
for equation in equations:
    r = solve(equation, (x1, x2))
    for solution in r:
        intersections.append((solution[x1], solution[x2]))

# Filter the points that are in the first quadrant
vertices = [point for point in intersections if point[0] >= 0 and point[1] >= 0]

# Print the vertices
for i, vertex in enumerate(vertices):
    print(f"Vertex {i+1}: {vertex}")

# Retrieve the x and y coordinates of the vertices
x_coords = [vertex[0] for vertex in vertices]
y_coords = [vertex[1] for vertex in vertices]

# Plot the vertices
plt.plot(x_coords, y_coords, 'ro')

# Traverse the list of equations, calculate the constraints, and plot them
for equation in sel:
    # Parse the equation and obtain the inequality
    parsed_eq = equation.replace('X1', 'x1').replace('X2', 'x2')
    inequality = parsed_eq.split()[1]

    # Solve the equation to obtain the linear constraint
    if inequality == '>=':
        r = solve(Eq(parse_expr(parsed_eq.replace('=', '-'))), x2)[0]
    elif inequality == '<=':
        r = solve(Eq(parse_expr(parsed_eq.replace('=', '-'))), x2)[0]
    else:
        raise ValueError("Invalid inequality sign")

    # Create the linear constraint function
    linear_constraint = lambdify(x1, r, 'numpy')

    # Evaluate the linear constraint in the range of x
    y = linear_constraint(x)

    # Plot the linear constraint
    plt.plot(x, y, label=equation)

# Adjust the plot limits
plt.xlim(-10, 10)
plt.ylim(-10, 10)

# Each constraint is plotted using plt.plot() with a specific label obtained from the list of equations.
# The plt.legend() function takes those labels and displays them in the plot as a legend that identifies each constraint.
plt.legend()

# Show the axis labels
plt.xlabel('x1')
plt.ylabel('x2')

# Set ticks in increments of 1 unit
plt.xticks(np.arange(-10, 11, 1))
plt.yticks(np.arange(-10, 11, 1))

plt.title('Constraint Graph')
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)
plt.grid(True, linestyle='--', alpha=0.7)

# Show the resulting plot
plt.show()

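I am trying to parse and evaluate equations with inequalities using the sympy library in Python. However, when the sel list contains an equality alongside the inequalities, I get a SyntaxError. The error message is: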

Traceback (most recent call last):
  File "plot_sel.py", line 26, in <module>
    equations.append(Eq(parse_expr(parsed_eq), 0))
    expr = eval(
  File "<string>", line 1
    Symbol ('x1' )+Symbol ('x2' )-Integer (4 )=Integer (0 )
                                              ^
SyntaxError: invalid syntax

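I think this error is caused by the use of eval to parse and evaluate the equation in the line expr = eval(equation.replace('X1', 'x1').replace('X2', 'x2')).
Why does this error occur, and how can I fix it?
If I replace = with ==, I get this error instead: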

File "plot_sel.py", line 49, in <module>
    r = solve(Eq(constraint, 0), x2)[0]
IndexError: list index out of range

wj8zmpe1 (answer #1)

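When building an equation for sympy's Eq function, a single = sign does not work; you need ==. On top of that, Eq takes two arguments: the expressions on either side of the ==. So you need to split the string at ==, as shown below. I also pass the local_dict argument to parse_expr, so the string replace calls are no longer needed.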

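variables = {"X1": x1, "X2": x2}
for equation in sel:
    if ">=" in equation or "<=" in equation:
        pass  # inequality handling goes here
    else:
        # Split at "==" and parse each side separately
        split_eq = equation.split("==")
        parsed_eq = [parse_expr(eq, local_dict=variables) for eq in split_eq]
        equations.append(Eq(*parsed_eq))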
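
As an alternative to splitting by hand, sympy's parser ships a convert_equals_signs transformation that turns a single = into an Eq. This is a different technique from the one in this answer, shown only as a minimal sketch; the input string and the variables mapping are taken from the question.

```python
from sympy import Eq, symbols
from sympy.parsing.sympy_parser import (
    convert_equals_signs,
    parse_expr,
    standard_transformations,
)

x1, x2 = symbols("x1 x2")
variables = {"X1": x1, "X2": x2}

# Add convert_equals_signs to the default transformations so "=" becomes Eq(...)
transformations = standard_transformations + (convert_equals_signs,)

equation = parse_expr(
    "X1 + X2 - 4 = 0",
    local_dict=variables,
    transformations=transformations,
)
print(type(equation).__name__, equation)
```

Note that this transformation targets the single = sign; a string that already uses == is still evaluated as a Python comparison.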

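Also, I assume you do not intend to evaluate your inequalities; you probably want to keep them as inequalities. As with ==, you need to split the string at the inequality sign and pass each side separately, together with the sign itself.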

for equation in sel:
    variables = {"X1":x1, "X2":x2}
    if ">=" in equation:
        split_eq = equation.split(">=")
        parsed_eq = [parse_expr(eq, local_dict=variables) for eq in split_eq] 
        linear_constraints.append(Relational(*parsed_eq, ">="))
    elif "<=" in equation:
        split_eq = equation.split("<=")
        parsed_eq = [parse_expr(eq, local_dict=variables) for eq in split_eq] 
        linear_constraints.append(Relational(*parsed_eq, "<="))
    elif "==" in equation:
        split_eq = equation.split("==")
        parsed_eq = [parse_expr(eq, local_dict=variables) for eq in split_eq] 
        equations.append(Eq(*parsed_eq))
    else:
        raise ValueError(f"Unknown symbol in equation: {equation}.")

Now you can call the solve function directly on the linear constraints and on the equations.
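
To make that last step concrete, here is a self-contained sketch that parses a sel list and solves each constraint's boundary line for x2. The helper name parse_constraint and the shortened sel list are assumptions made for illustration; they are not part of the original code.

```python
from sympy import Eq, parse_expr, solve, symbols
from sympy.core.relational import Relational

x1, x2 = symbols("x1 x2")
variables = {"X1": x1, "X2": x2}


def parse_constraint(text):
    """Hypothetical helper: turn 'lhs OP rhs' into a sympy relational."""
    for op in (">=", "<=", "=="):
        if op in text:
            lhs, rhs = (
                parse_expr(part.strip(), local_dict=variables)
                for part in text.split(op)
            )
            # Relational dispatches on the operator string, so "==" yields an Eq
            return Relational(lhs, rhs, op)
    raise ValueError(f"Unknown symbol in equation: {text}.")


sel = ["2 * X1 + 1 * X2 >= 2", "3 * X1 + 4 * X2 <= 12", "X1 + X2 - 4 == 0"]
constraints = [parse_constraint(text) for text in sel]

# Boundary line of every constraint (the relation replaced by equality), solved for x2
boundaries = [solve(Eq(c.lhs, c.rhs), x2)[0] for c in constraints]
for text, boundary in zip(sel, boundaries):
    print(f"{text}  ->  x2 = {boundary}")
```

Keeping each constraint as a relational object means the operator stays available through rel_op, so later code can still tell >= from <= when deciding which side of the line is feasible.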

zyfwsgd6 (answer #2)

Testing these sel lists in an ipython session where the symbols x1 and x2 are already defined:

sel = [2 * x1 + 1 * x2 >= 2, 3 * x1 + 4 * x2 = 12, 4 * x1 + 3 * x2 <= 12]
  Cell In[10], line 1
    sel = [2 * x1 + 1 * x2 >= 2, 3 * x1 + 4 * x2 = 12, 4 * x1 + 3 * x2 <= 12]
                                 ^
SyntaxError: cannot assign to expression here. Maybe you meant '==' instead of '='?

This points at the problem with = more clearly than the error in your version does.
But if I use == instead:

sel = [2 * x1 + 1 * x2 >= 2, 3 * x1 + 4 * x2 == 12, 4 * x1 + 3 * x2 <= 12]

type(sel[0])
Out[12]: sympy.core.relational.GreaterThan

type(sel[1])
Out[13]: bool

type(sel[2])
Out[14]: sympy.core.relational.LessThan

sel
Out[15]: [2*x1 + x2 >= 2, False, 4*x1 + 3*x2 <= 12]

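The second entry, the equality, does not produce a relational: == compares the two expressions structurally and returns a plain bool. Using sympy's Eq instead: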

In [21]: sel = [2 * x1 + 1 * x2 >= 2, sp.Eq(3 * x1 + 4 * x2, 12), 4 * x1 + 3 * x2 <= 12]

In [22]: sel
Out[22]: [2*x1 + x2 >= 2, Eq(3*x1 + 4*x2, 12), 4*x1 + 3*x2 <= 12]

In [23]: type(sel[1])
Out[23]: sympy.core.relational.Equality

You do use Eq later in your code.
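
The same behaviour can be reproduced in a plain script. The sketch below is only an illustration under the assumption of a default parse_expr call (no extra transformations); the variable names are made up for the example.

```python
from sympy import Eq, parse_expr, solve, symbols

x1, x2 = symbols("x1 x2")

# Python's == compares the two expression trees structurally, so the result is a bool
structural = (3 * x1 + 4 * x2 == 12)

# Eq builds a symbolic equation that solve can work with
equation = Eq(3 * x1 + 4 * x2, 12)

# parse_expr evaluates the string as Python code, so "==" inside it behaves the same way
parsed = parse_expr("3 * x1 + 4 * x2 == 12")

# Solve the symbolic equation for x2 to get the line to plot
line = solve(equation, x2)[0]

print(structural, type(equation).__name__, parsed, line)
```

This is why switching the string from = to == removes the SyntaxError but leaves nothing symbolic to solve: the parsed value is already a truth value rather than an equation.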
