numpy 在循环中生成的堆叠图

9nvpjoqh  于 2023-04-30  发布在  其他
关注(0)|答案(1)|浏览(169)

我正在运行一个if循环,并希望将结果图堆叠在一个网格中。这是我的示例代码,它生成两个随机变量,并在两个条件下绘制第三个随机变量:

import numpy as np
import matplotlib.pyplot as plt

# Set seed for reproducibility
np.random.seed(42)

# Generate 10 realizations of the uniform distribution between -1 and 1 for x and y
x = np.random.uniform(low=-1, high=1, size=10)
y = np.random.uniform(low=-1, high=1, size=10)

# Create empty list to store valid plots
valid_plots = []

# Initialize an empty list to store the current row of plots
current_row = []

# Loop through each realization of x and y
for i in range(len(x)):
    # Check if both x and y are positive
    if x[i] > 0 and y[i] < 0:
        # Generate 100 values of z
        z = np.linspace(-1, 1, 100)
        # Compute the function z = xy*z^2
        z_func = x[i] * y[i] * z*z
        # Plot the function
        fig, ax = plt.subplots()
        ax.plot(z, z_func)
        # If there are now two plots in the current row, append the row to valid_plots and start a new row
        if len(current_row) % 2 == 1:
            valid_plots.append(current_row)
            current_row = []
        # Append the current plot to the current row
        current_row.append(ax)
# If there is only one plot in the last row, append the row to valid_plots
if len(current_row) > 0 and len(current_row) % 2 == 1:
    current_row.append(plt.gca())
    valid_plots.append(current_row)

# Create a figure with subplots for each valid plot
num_rows = len(valid_plots)
fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4 * num_rows))
for i, row in enumerate(valid_plots):
    for j, ax in enumerate(row):
        # Check if the plot has any lines before accessing ax.lines[0]
        if len(ax.lines) > 0:
            axes[i, j].plot(ax.lines[0].get_xdata(), ax.lines[0].get_ydata())
plt.show()

输出的问题在于它生成了两个空图,然后开始垂直堆叠:

能帮我吗我也对实现这一结果的更有效的方法感兴趣。

agyaoht7

agyaoht71#

根据我所理解的,你得到了一个随机数的图,你只想画那些对于每对x和y都有x[i]>0 and y[i]<0的图。你也不想有空白的地方。
在您的实现中,主要问题是您首先创建所有图,然后尝试创建Nx 2图。请参考下面的代码。我首先检查了所有的对,其中x[i]>0 and y[i]<0和收集的索引单独在一个数组称为valid_ids。……还没有策划任何事情。
然后,由于我们知道len(valid_ids)需要多少个子图,我在一个for循环中绘制它。请注意,math.ceil()将给予我2个图的数量上限,所以我们得到4个图的2行或5个图的3行。这样,我们最终将所有z作为子图绘制在一个镜头中。最后,如果有奇数个子图,则可以使用fig.delaxes()删除多余的图。希望这就是你要找的。..

import numpy as np
import matplotlib.pyplot as plt

# Set seed for reproducibility
np.random.seed(42)

# Generate 10 realizations of the uniform distribution between -1 and 1 for x and y
x = np.random.uniform(low=-1, high=1, size=10)
y = np.random.uniform(low=-1, high=1, size=10)

valid_ids=[]
for i, (j, k) in enumerate(zip(x, y)):
    if j > 0 and k < 0:
        valid_ids.append(i)

import math
num_rows=math.ceil(len(valid_ids)/2)

if len(valid_ids) > 0: ## No rows, don't plot anything 
    if len(valid_ids) > 2: ## More than one row, plot as axes[row, col]
        fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4 * num_rows))
        for i in range(len(valid_ids)):
            # Generate 100 values of z
            z = np.linspace(-1, 1, 100)
            # Compute the function z = xy*z^2
            z_func = x[valid_ids[i]] * y[valid_ids[i]] * z*z
            # Plot the function
            axes[int(i/2), i%2].plot(z, z_func)
        # If ODD number of plots, then remove the last blank plot    
        if len(valid_ids)%2 == 1:
            fig.delaxes(axes[num_rows-1,1])
    else: ## Single row, plot as axes[col] 
        fig, axes = plt.subplots(num_rows, 2, figsize=(12, 4))
        for i in range(len(valid_ids)):
            # Generate 100 values of z
            z = np.linspace(-1, 1, 100)
            # Compute the function z = xy*z^2
            z_func = x[valid_ids[i]] * y[valid_ids[i]] * z*z
            # Plot the function
            axes[i].plot(z, z_func)
        # If just one plot, then remove the second one
        if len(valid_ids) == 1:
            fig.delaxes(axes[1])
        
    plt.show()

输出图-在我的情况下有3个图。..

相关问题