为了账号安全,请及时绑定邮箱和手机立即绑定

pythonplotly:不固定数量的痕迹

pythonplotly:不固定数量的痕迹

一只斗牛犬 2023-07-27 10:13:00
我的代码从 .xlsx 文件读取数据,并使用plotly 绘制气泡图。 气泡图 当我知道需要绘制多少条迹线时,任务就很容易了。然而,当由于行数是可变的而跟踪数不固定时,我感到困惑。       1991  1992  1993  1994  1995  1996  1997US       10    14    16    18    20    42    64JAPAN   100    30    70    85    30    42    64CN       50    22    30    65    70    66    60这是我未完成的代码:# Version 2 could read data from .xlsx file.import plotly as pyimport plotly.graph_objs as goimport openpyxlwb = openpyxl.load_workbook(('grape output.xlsx'))     sheet = wb['Sheet1']       row_max = sheet.max_rowcol_max = sheet.max_columnl=[]for row_n in range(row_max-1):    l.append([])    for col_n in range(col_max-1):        l[row_n].append(sheet.cell(row=row_n+2, column=col_n+2).value)trace0 = go.Scatter(    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],    y=['US', 'US', 'US', 'US', 'US', 'US', 'US'],    mode='markers+text',    marker=dict(        color='rgb(150,204,90)',        size= l[0],        showscale = False,        ),    text=list(map(str, l[0])),         textposition='middle center',   )trace1 = go.Scatter(    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],    y=['JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN'],    mode='markers+text',    marker=dict(        color='rgb(255, 130, 71)',        size=l[1],        showscale=False,    ),    text=list(map(str,l[1])),    textposition='middle center',)trace2 = go.Scatter(    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],    y=['CN', 'CN', 'CN', 'CN', 'CN', 'CN', 'CN'],    mode='markers+text',    marker=dict(        color='rgb(255, 193, 37)',        size=l[2],        showscale=False,    ),    text=list(map(str,l[2])),    textposition='middle center',)你能教我怎么画它们吗?谢谢
查看完整描述

3 回答

?
暮色呼如

TA贡献1853条经验 获得超9个赞

我认为有一种更灵活的方法可以使用它,plotly.express特别是如果您不想定义颜色。


这个想法是正确地转换数据。


数据

import pandas as pd

df = pd.DataFrame({1991:[10,100,50], 1992:[14,30,22], 1993:[16,70,30], 1994:[18,85,65], 1995:[20,30,70], 1996:[42,42,66], 1997:[64,64,60]})

df.index = ['US','JAPAN','CN']

df = df.T.unstack()\

      .reset_index()\

      .rename(columns={"level_0": "country",

                       "level_1": "year",

                       0: "n"})

print(df)

   country  year    n

0       US  1991   10

1       US  1992   14

2       US  1993   16

3       US  1994   18

4       US  1995   20

5       US  1996   42

6       US  1997   64

7    JAPAN  1991  100

8    JAPAN  1992   30

9    JAPAN  1993   70

10   JAPAN  1994   85

11   JAPAN  1995   30

12   JAPAN  1996   42

13   JAPAN  1997   64

14      CN  1991   50

15      CN  1992   22

16      CN  1993   30

17      CN  1994   65

18      CN  1995   70

19      CN  1996   66

20      CN  1997   60

使用plotly.express

现在您的数据是长格式,您可以使用plotly.express如下


import plotly.express as px

fig = px.scatter(df,

                 x="year",

                 y="country",

                 size="n",

                 color="country",

                 text="n",

                 size_max=50 # you need this otherwise the bubble are too small

                )


fig.update_layout(plot_bgcolor='rgb(10, 10, 10)',  

                  paper_bgcolor='rgb(20, 55, 100)',  

                  font={'size': 15,

                        'family': 'sans-serif',

                        'color': 'rgb(255, 255, 255)'

                       },

                  width=1000,

                  height=500,

                  xaxis=dict(title='Output of grapes per year in selected countries', ),  

                  showlegend=False,

                  margin=dict(l=100, r=100, t=100, b=100),

                  hovermode = False,)

# Uncomment this if you don't wont country as yaxis title

# fig.layout.yaxis.title.text = None

fig.show()

//img4.sycdn.imooc.com/64c1d2d30001855506450306.jpg

查看完整回答
反对 回复 2023-07-27
?
慕妹3242003

TA贡献1824条经验 获得超6个赞

我应该指出,如果您将原始数据附加为文本或可以更轻松地复制和粘贴的内容,那么您的代码将更具可重复性。但是,无论如何,我仍然可以回答您的问题并为您指出正确的方向。


您应该做的是使用循环,并从查看行开始data = [trace0, trace1, trace2]。正如您所注意到的,如果您有 100 个国家/地区而不是 3 个国家/地区,此方法将无法扩展。


相反,您可以data使用列表理解将其创建为列表,并更新每个跟踪中发生更改的部分。trace0除了国家/地区、值和颜色之外, 、trace1、trace2没有太大区别。为了向您展示我的意思,我使用 DataFrame 重新创建了您的数据,然后创建了包含您的国家/地区和颜色的单独列表。


# Version 2 could read data from .xlsx file.

import plotly as py

import plotly.graph_objs as go

import openpyxl


# wb = openpyxl.load_workbook(('grape output.xlsx'))     

# sheet = wb['Sheet1']       

# row_max = sheet.max_row

# col_max = sheet.max_column

# l=[]


# for row_n in range(row_max-1):

#     l.append([])

#     for col_n in range(col_max-1):

#         l[row_n].append(sheet.cell(row=row_n+2, column=col_n+2).value)


import pandas as pd


df = pd.DataFrame({1991:[10,100,50], 1992:[14,30,22], 1993:[16,70,30], 1994:[18,85,65], 1995:[20,30,70], 1996:[42,42,66], 1997:[64,64,60]})

df.index = ['US','JAPAN','CN']

colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)']


data = [go.Scatter(

    x=df.columns,

    y=[country]*len(df.columns),

    mode='markers+text',

    marker=dict(

        color=colors[num],

        size= df.loc[country],

        showscale = False,

        ),

    text=list(map(str, df.loc[country])),     

    textposition='middle center',   

    )

    for num, country in enumerate(df.index)

]


layout = go.Layout(plot_bgcolor='rgb(10, 10, 10)',  

                   paper_bgcolor='rgb(20, 55, 100)',  

                   font={               

                       'size': 15,

                       'family': 'sans-serif',

                       'color': 'rgb(255, 255, 255)'  

                   },

                   width=1000,

                   height=500,

                   xaxis=dict(title='Output of grapes per year in US, JAPAN and CN', ),  

                   showlegend=False,

                   margin=dict(l=100, r=100, t=100, b=100),

                   hovermode = False,       

                   )


# data = [trace0, trace1, trace2]

fig = go.Figure(data=data, layout=layout)

fig.show()


# py.offline.init_notebook_mode()

# py.offline.plot(fig, filename='basic-scatter.html')

//img1.sycdn.imooc.com//64c1d2e50001c6c206460285.jpg

如果我随后向 DataFrame 添加一个包含 1991-1997 年值的测试国家/地区,则不需要更改其余代码,气泡图将相应更新。

# I added a test country with datadf = pd.DataFrame({1991:[10,100,50,10], 1992:[14,30,22,20], 1993:[16,70,30,30], 1994:[18,85,65,40], 1995:[20,30,70,50], 1996:[42,42,66,60], 1997:[64,64,60,70]})
df.index = ['US','JAPAN','CN','TEST']
colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)','rgb(100, 100, 100)']

//img1.sycdn.imooc.com//64c1d2fd0001ee0706400277.jpg


查看完整回答
反对 回复 2023-07-27
?
慕田峪9158850

TA贡献1794条经验 获得超7个赞

代码已更新到版本 2,可以从 .xlsx 文件读取数据并绘制气泡图。与之前的数据相比,名为“grape output.xlsx”的原始数据添加了新项目:


             1991  1992  1993  1994  1995  1996  1997  1998  1999

         US    10    14    16    18    20    42    64   100    50

      JAPAN   100    30    70    85    30    42    64    98    24

         CN    50    22    30    65    70    66    60    45    45

      INDIA    90    88    35    50    90    60    40    66    76

         UK    40    50    70    50    25    30    22    40    60

这是代码:


# Version 2 

import plotly as py

import plotly.graph_objs as go

import openpyxl

import pandas as pd



wb = openpyxl.load_workbook('grape output.xlsx')

sheet = wb['Sheet1']

row_max = sheet.max_row

col_max = sheet.max_column

first_row_list = []

first_col_list = []

for col_n in range(2, col_max+1):

    first_row_list.append(sheet.cell(row=1, column=col_n).value)

for row_n in range(2,row_max+1):

    first_col_list.append(sheet.cell(row=row_n, column=1).value)


data_all = pd.read_excel('grape output.xlsx')

data = data_all.loc[:,first_row_list]


df = pd.DataFrame(data)

df.index = first_col_list

colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)','rgb(180,240,190)','rgb(255, 10, 1)',

          'rgb(25, 19, 3)','rgb(100, 100, 100)','rgb(45,24,200)','rgb(33, 58, 108)','rgb(35, 208, 232)']


data = [go.Scatter(

    x=df.columns,

    y=[country]*len(df.columns),

    mode='markers+text',

    marker=dict(

        color=colors[num],

        size= df.loc[country],

        showscale = False,

        ),

    text=list(map(str, df.loc[country])),

    textposition='middle center',

    )

    for num, country in enumerate(reversed(df.index))

]


layout = go.Layout(plot_bgcolor='rgb(10, 10, 10)',

                   paper_bgcolor='rgb(20, 55, 100)',

                   font={

                       'size': 15,

                       'family': 'sans-serif',

                       'color': 'rgb(255, 255, 255)'

                   },

                   width=1000,

                   height=800,

                   xaxis=dict(title='Output of grapes per year in US, JAPAN and CN'),

                   showlegend=False,

                   margin=dict(l=100, r=100, t=100, b=100),

                   hovermode = False,

                   )


fig = go.Figure(data=data, layout=layout)

py.offline.plot(fig, filename='basic-scatter.html')

现在结果是这样的:

//img1.sycdn.imooc.com//64c1d31400016b3e06100468.jpg

还有一些小问题:

  1. 如何去掉 1990 和 2000 这两个数字以及 1990 和 2000 的白色垂直线?

  2. 如何为1991、1993、1995、1997、1999画白线并以横坐标轴显示所有这些年份?

请更正代码 Versinon 2 以改进它。谢谢你!


查看完整回答
反对 回复 2023-07-27
  • 3 回答
  • 0 关注
  • 108 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信