如何将FastAPI中的Pydantic模型转换为Pandas DataFrame?

wbrvyc0a  于 2023-01-07  发布在  其他
关注(0)|答案(1)|浏览(324)

我正在尝试将Pydantic模型转换为Pandas DataFrame,但是我遇到了各种错误。
下面是代码:

from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
import pickle
import sklearn
import pandas as pd
import numpy as np

class Userdata(BaseModel):
  current_res_month_dec: Optional[int] = 0
  current_res_month_nov:  Optional[int] = 0

async def return_recurrent_user_predictions_gb(user_data: Userdata):

      empty_dataframe =  pd.DataFrame([Userdata(**{
      'current_res_month_dec': user_data.current_res_month_dec,
      'current_res_month_nov': user_data.current_res_month_nov})], ignore_index=True)

这是在我的本地环境中尝试通过/docs执行时返回的DataFrame

Response body
Download
{
  "0": {
    "0": [
      "current_res_month_dec",
      0
    ]
  },
  "1": {
    "0": [
      "current_res_month_nov",
      0
    ]
  }

但如果我用这个DataFrame来预测

model_has_afternoon = pickle.load(open('./models/model_gbclf_prob_current_product_has_afternoon.pickle', 'rb'))
result_afternoon = model_has_afternoon.predict_proba(empty_dataframe)[:, 1]

我得到这个错误:

ValueError: setting an array element with a sequence.

我以前试过构建自己的DataFrame,预测应该可以与DataFrame一起工作。

fivyi3re

fivyi3re1#

首先需要使用Pydantic的dict()方法将Pydantic模型转换为字典。注意,其他方法,如Python的dict()函数和.__dict__属性,已被发现比Pydantic的dict()方法更快但是,由于您使用的是Pydantic模型,因此最好使用Pydantic的dict()方法,然后将字典传递给pandas.DataFrame()(用方括号括起来);例如pd.DataFrame([data.dict()])。如this answer中所述,当需要将传递的dict的键设置为 columns,将值设置为 rows 时,可以使用此方法。如果需要指定不同的方向,也可以使用pandas.DataFrame.from_dict()

工作示例
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
import pandas as pd

app = FastAPI()

class Userdata(BaseModel):
  col1: Optional[int] = 0
  col2:  Optional[int] = 0
  col3:  str = "foo"

@app.post('/submit')
def submit_data(data: Userdata):
    df = pd.DataFrame([data.dict()])
    return "Success"
更多选项

正如您提到的,您希望使用DataFrame进行机器学习预测,请注意,还有一些其他选项可以将数据传递给predict()predict_proba()函数,而无需创建DataFrame

model.predict([[data.col1, data.col2, data.col3]])

以及

model.predict([list(data.dict().values())])

请查看this answer以了解更多细节。如果您还需要使用JSON格式的DataFrame响应客户端,请查看here

相关问题