Python(Plotly):用同一套通用配置设置所有子图的坐标轴

mzaanser  于 2023-05-05  发布在  Python
关注(0)|答案(1)|浏览(110)

我想用相同的坐标轴配置来设置所有子图。下面的代码可以工作,但每当子图数量增加时,就需要再添加一个 elif 分支来处理新子图的坐标轴配置。如何简化这段代码,把坐标轴名称(如 xaxis2、yaxis2)作为变量传入,从而处理任意数量的子图?

# Apply the same axis styling to every subplot row.
# A keyword-argument name cannot be computed at runtime
# (`"xaxis%d" % (i + 1) = xaxis` is a syntax error), so instead of naming the
# layout keys xaxis1/xaxis2/... each subplot's axes are selected by row/col.
# This needs no extra branch when more subplots are added.
# NOTE: `df`, `fig`, `xname`, `colnames` and `go` are assumed to be defined by
# the surrounding script (see the full listing below).
xaxis, yaxis = get_xyaxis()
for i, yname in enumerate(colnames):
    row = i + 1  # make_subplots rows are 1-based
    trace1 = go.Scatter(
        x=df[xname],
        y=df[yname],
        name=yname)
    fig.add_trace(trace1, row=row, col=1)
    fig.update_xaxes(xaxis, row=row, col=1)
    fig.update_yaxes(yaxis, row=row, col=1)

我试过写成 "xaxis%d" % (i + 1) = xaxis,但这样不起作用!
完整代码:
import re

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_line(df, pngname):
    """Plot each data column of *df* as its own line subplot, stacked vertically.

    The first column of *df* supplies the shared x values; every remaining
    column becomes one subplot row containing a single line+marker trace.
    All subplots receive the same axis styling from ``get_xyaxis()``, so the
    number of data columns is not limited.

    Args:
        df: DataFrame whose first column holds the x values and whose
            remaining columns hold the y series.
        pngname: Output image path. NOTE(review): currently unused -- the
            figure is only displayed with ``fig.show()``.
    """
    fontsize = 10
    title = "demo"
    xname = df.columns[0]
    colnames = df.columns[1:]
    n = len(colnames)

    fig = make_subplots(
        rows=n, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
    )

    for i, yname in enumerate(colnames):
        trace1 = go.Scatter(
            x=df[xname],
            y=df[yname],
            text=df[yname],
            textposition='top center',
            mode='lines+markers',
            marker=dict(
                size=10,
                line=dict(width=0, color='DarkSlateGrey')),
            name=yname)
        fig.add_trace(trace1, row=i + 1, col=1)

    # Without a row/col selector, update_xaxes/update_yaxes style every x/y
    # axis in the figure. This replaces the per-index if/elif chain, which
    # only covered the first three subplots.
    xaxis, yaxis = get_xyaxis()
    fig.update_xaxes(xaxis)
    fig.update_yaxes(yaxis)

    fig.update_layout(
        margin=dict(l=20, t=40, r=10, b=40),
        plot_bgcolor='#ffffff',  # 'rgb(12,163,135)',
        paper_bgcolor='#ffffff',
        title=title,
        title_x=0.5,
        showlegend=True,
        legend=dict(x=.02, y=1.05),
        barmode='group',
        bargap=0.05,
        bargroupgap=0.0,
        font=dict(
            family="Courier New, monospace",
            size=fontsize,
            color="black"
        ),
    )
    fig.show()

def get_xyaxis():
    """Return the common subplot axis styling as a ``[xaxis, yaxis]`` list.

    Each element is a fresh dict of Plotly axis properties, so callers may
    mutate the result without affecting later calls.
    """
    # Line/grid settings that the x and y axes have in common.
    shared = dict(
        title_standoff=1,
        showline=True,
        linecolor='black',
        color='black',
        linewidth=.5,
        showgrid=True,
        gridcolor='grey',
        gridwidth=.5,
        griddash='solid',  # alternative: 'dot'
    )
    # x axis: slanted tick labels with tick marks drawn outside the plot.
    x_style = {**shared, 'tickangle': -15, 'ticks': 'outside'}
    # y axis: additionally draws a thin grey zero line and its tick labels.
    y_style = {
        **shared,
        'zeroline': True,
        'zerolinecolor': 'grey',
        'zerolinewidth': .5,
        'showticklabels': True,
    }
    return [x_style, y_style]

def main():
    """Build a small demo DataFrame, print it, then plot it with plot_line."""
    rows = [
        ['AAA', 1, 2, 3],
        ['BBB', 3, 2, 3],
        ['CCC', 2, 1, 2],
        ['DDD', 4, 2, 3],
    ]
    frame = pd.DataFrame(rows, columns=['name', 'v1', 'v2', 'v3'])
    print(frame)
    plot_line(frame, "./demo.png")


main()

输出:(原图片未能保留)

6rqinv9w

6rqinv9w1#

既然您已在问题中贴出了全部代码,我删除了之前那条未能解决问题的评论,改为在这里作答。使用坐标轴更新方法 fig.update_xaxes() / fig.update_yaxes(),可以把同一套坐标轴设置一次性应用到所有子图。

import re
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_line(df, pngname):
    """Draw one line subplot per data column of *df*, sharing the x axis.

    Column 0 of *df* gives the x values; each following column is plotted
    in its own subplot row. The axis styling from ``get_xyaxis()`` is
    applied to all subplots at once.

    Args:
        df: DataFrame with x values in the first column and y series in the
            remaining columns.
        pngname: Output image path. NOTE(review): not used by this function;
            the figure is only shown with ``fig.show()``.
    """
    fontsize = 10
    title = "demo"
    xname = df.columns[0]
    colnames = df.columns[1:]
    n = len(colnames)
    fig = make_subplots(
        rows=n, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
    )

    for i, yname in enumerate(colnames):
        trace1 = go.Scatter(
            x=df[xname],
            y=df[yname],
            text=df[yname],
            textposition='top center',
            mode='lines+markers',
            marker=dict(
                size=10,
                line=dict(width=0, color='DarkSlateGrey')),
            name=yname)
        fig.add_trace(trace1, row=i + 1, col=1)

    # With no row/col selector these calls already target every subplot's
    # axes, so they run once after the loop rather than once per trace.
    xaxis, yaxis = get_xyaxis()
    fig.update_xaxes(xaxis)
    fig.update_yaxes(yaxis)

    fig.update_layout(
        margin=dict(l=20, t=40, r=10, b=40),
        plot_bgcolor='#ffffff',  # 'rgb(12,163,135)',
        paper_bgcolor='#ffffff',
        title=title,
        title_x=0.5,
        showlegend=True,
        legend=dict(x=.02, y=1.05),
        barmode='group',
        bargap=0.05,
        bargroupgap=0.0,
        font=dict(
            family="Courier New, monospace",
            size=fontsize,
            color="black"
        ),
    )
    fig.show()

def get_xyaxis():
    """Build the axis property dicts shared by every subplot.

    Returns:
        A two-element list ``[xaxis, yaxis]`` of newly created Plotly axis
        property dicts.
    """
    xaxis = {
        'title_standoff': 1,
        'tickangle': -15,
        'showline': True,
        'linecolor': 'black',
        'color': 'black',
        'linewidth': .5,
        'ticks': 'outside',
        'showgrid': True,
        'gridcolor': 'grey',
        'gridwidth': .5,
        'griddash': 'solid',  # or 'dot'
    }
    yaxis = {
        'title_standoff': 1,
        'showline': True,
        'linecolor': 'black',
        'color': 'black',
        'linewidth': .5,
        'showgrid': True,
        'gridcolor': 'grey',
        'gridwidth': .5,
        'griddash': 'solid',  # or 'dot'
        'zeroline': True,
        'zerolinecolor': 'grey',
        'zerolinewidth': .5,
        'showticklabels': True,
    }
    axes = [xaxis, yaxis]
    return axes

def main():
    """Create the demo table, echo it to stdout, and draw the subplots."""
    df = pd.DataFrame({
        'name': ['AAA', 'BBB', 'CCC', 'DDD'],
        'v1': [1, 3, 2, 4],
        'v2': [2, 2, 1, 2],
        'v3': [3, 3, 2, 3],
    })
    print(df)
    plot_line(df, "./demo.png")


main()

更新:我尝试向各个子图添加标注(annotation),但只有最后一个生效。
import re

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_line(df, pngname):
    """Plot each data column of *df* in its own subplot and annotate big steps.

    Column 0 of *df* provides the shared x values. Every other column gets a
    line+marker trace in its own subplot row, and ``add_anns`` is called for
    that row to mark large changes between consecutive points. Axis styling
    from ``get_xyaxis()`` is applied to all subplots.

    Args:
        df: DataFrame with x values in the first column and y series in the
            remaining columns.
        pngname: Output image path. NOTE(review): unused; the figure is only
            displayed with ``fig.show()``.
    """
    fontsize = 10
    title = "demo"
    xname = df.columns[0]
    colnames = df.columns[1:]
    n = len(colnames)

    fig = make_subplots(
        rows=n, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
    )

    for i, yname in enumerate(colnames):
        trace1 = go.Scatter(
            x=df[xname],
            y=df[yname],
            text=df[yname],
            textposition='top center',
            mode='lines+markers',
            marker=dict(
                size=10,
                line=dict(width=0, color='DarkSlateGrey')),
            name=yname)
        fig.add_trace(trace1, row=i + 1, col=1)
        # Annotations for subplot i reference axes x{i+1}/y{i+1}.
        add_anns(fig, df, xname, yname, i)

    # A single selector-less update styles every subplot's axes, so it does
    # not need to be repeated inside the loop.
    xaxis, yaxis = get_xyaxis()
    fig.update_xaxes(xaxis)
    fig.update_yaxes(yaxis)

    fig.update_layout(
        margin=dict(l=20, t=40, r=10, b=40),
        plot_bgcolor='#ffffff',  # 'rgb(12,163,135)',
        paper_bgcolor='#ffffff',
        title=title,
        title_x=0.5,
        showlegend=True,
        legend=dict(x=.02, y=1.05),
        barmode='group',
        bargap=0.05,
        bargroupgap=0.0,
        font=dict(
            family="Courier New, monospace",
            size=fontsize,
            color="black"
        ),
    )
    fig.show()

def get_xyaxis():
    """Return ``[xaxis, yaxis]``: the Plotly axis settings used by all subplots."""
    # Axis line drawn in thin black.
    line = dict(showline=True, linecolor='black', color='black', linewidth=.5)
    # Thin solid grey grid ('dot' is the other style that was tried).
    grid = dict(showgrid=True, gridcolor='grey', gridwidth=.5, griddash='solid')

    xaxis = dict(title_standoff=1, tickangle=-15, ticks='outside', **line, **grid)
    yaxis = dict(
        title_standoff=1,
        **line,
        **grid,
        zeroline=True,
        zerolinecolor='grey',
        zerolinewidth=.5,
        showticklabels=True,
    )
    return [xaxis, yaxis]

def add_anns(fig, df, xname, yname, i, min_delta=2):
    """Mark large step changes of one series on subplot row ``i + 1``.

    For each row of *df* the change ``dy`` of column *yname* relative to the
    previous row is computed (0 for the first row). When
    ``abs(dy) >= min_delta``, ``add_vline`` draws a vertical marker at that
    row's x value, spanning from the previous y to the current y and
    labelled with ``dy``.

    Args:
        fig: Plotly figure to annotate.
        df: DataFrame containing columns *xname* and *yname*.
        xname: Column providing the x position of each marker.
        yname: Column whose consecutive differences are examined.
        i: Zero-based subplot index; markers use axis refs x{i+1}/y{i+1}.
        min_delta: Smallest absolute change that gets a marker (2 was the
            previously hard-coded threshold).
    """
    if df.empty:
        return
    xref = "x%d" % (i + 1)
    yref = "y%d" % (i + 1)
    # iloc (positional) instead of loc[0]: works for any index labels.
    prev = df.iloc[0]
    for _, row in df.iterrows():
        dy = row[yname] - prev[yname]
        print("----", dy)
        if abs(dy) >= min_delta:
            add_vline(fig, row[xname], row[yname], prev[yname], xref, yref, "%.1f" % (dy))
        # Bug fix: this update previously sat after the loop, so every row
        # was compared against the first row instead of its predecessor.
        prev = row

def add_vline(fig, x0, y0, y1, xref, yref, text=None):
    """Draw a labelled vertical marker between two y values at x position *x0*.

    Adds six annotations to *fig*:

    * a double-headed arrow from ``(x0, y0)`` to ``(x0, y1)``;
    * four short horizontal ticks, one on each side of each end point;
    * a text label at the vertical midpoint.

    Args:
        fig: Plotly figure to annotate.
        x0: x position in *xref* coordinates.
        y0: y value of one end, in *yref* coordinates.
        y1: y value of the other end, in *yref* coordinates.
        xref: Axis reference of the target subplot's x axis, e.g. "x2".
        yref: Axis reference of the target subplot's y axis, e.g. "y2".
        text: Label text; defaults to ``y1 - y0`` formatted with one decimal.
    """
    dw = 10  # pixels; tick offset used with axref='pixel'
    if text is None:
        text = "%.1f" % (y1 - y0)
    # Vertical double-headed arrow between the two end points.
    fig.add_annotation(
        x=x0, y=y0, ax=x0, ay=y1,
        xref=xref, yref=yref, axref=xref, ayref=yref,
        showarrow=True, text='',
        arrowhead=2, arrowside='start+end', arrowsize=2, arrowwidth=.5, arrowcolor='black',
    )
    # Horizontal ticks at both ends, in the original order:
    # start-left, start-right, end-left, end-right.
    for y in (y0, y1):
        for ax in (-dw, dw):
            fig.add_annotation(
                x=x0, y=y, ax=ax, ay=y,
                xref=xref, yref=yref, axref='pixel', ayref=yref,
                showarrow=True, text='', arrowwidth=.5, arrowcolor='black',
            )
    # Text label centred between the end points, on a white background.
    fig.add_annotation(
        x=x0, y=(y0 + y1) / 2,
        xref=xref, yref=yref,
        text=text, textangle=0, font=dict(color='black', size=14),
        bgcolor='white',
        showarrow=False, arrowhead=1, arrowwidth=2,
    )

def main():
    """Assemble the demo records, print the table and render the plot."""
    columns = ['name', 'v1', 'v2', 'v3']
    records = [
        ('AAA', 1, 2, 3),
        ('BBB', 3, 2, 3),
        ('CCC', 2, 1, 2),
        ('DDD', 4, 2, 3),
    ]
    table = pd.DataFrame([list(r) for r in records], columns=columns)
    print(table)
    plot_line(table, "./demo.png")


main()

输出:

相关问题