JSON: encoding/parsing problem when calling the Hugging Face API with an HTTP POST

hgtggwj0 posted on 2023-01-22 in Other

I am trying to replicate a Python API call in R using the httr package.
The Python code from Hugging Face:

import json
import requests

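API_TOKEN = ""  # your Hugging Face access token

def query(payload='', parameters=None, options=None):
    if options is None:  # avoid a mutable default argument
        options = {'use_cache': False}
    API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B"
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
    # model settings are nested under "parameters", API settings under "options"
    body = {"inputs": payload, 'parameters': parameters, 'options': options}
    response = requests.request("POST", API_URL, headers=headers, data=json.dumps(body))
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        error = response.json()['error']
        # "error" may be a single string or a list of messages
        if isinstance(error, list):
            error = " ".join(error)
        return "Error: " + error
    else:
        return response.json()[0]['generated_text']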

parameters = {
    'max_new_tokens':25,  # number of generated tokens
    'temperature': 0.5,   # controlling the randomness of generations
    'end_sequence': "###" # stopping sequence for generation
}

prompt="Tweet: \"I hate it when my phone battery dies.\"\n" + \
"Sentiment: Negative\n" + \
"###\n" + \
"Tweet: \"My day has been 👍\"\n" + \
"Sentiment: Positive\n" + \
"###\n" + \
"Tweet: \"This is the link to the article\"\n" + \
"Sentiment: Neutral\n"+ \
"###\n" + \
"Tweet: \"This new music video was incredible\"\n" + \
"Sentiment:"             # few-shot prompt

options = {'use_cache': False}
data = query(prompt, parameters, options)
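To see exactly what an R version has to reproduce, the request body built by `query()` can be printed instead of sent. This sketch uses a shortened prompt, and the helper `build_body` is hypothetical, not part of the original code. The point it shows is that the generation settings travel as a nested `parameters` object inside a JSON document, not as top-level fields.

```python
import json

# Illustrative values mirroring the Python snippet above (prompt shortened).
prompt = "Tweet: \"This new music video was incredible\"\nSentiment:"
parameters = {"max_new_tokens": 25, "temperature": 0.5, "end_sequence": "###"}
options = {"use_cache": False}


def build_body(payload, parameters=None, options=None):
    """Serialize the request body the same way query() does, without sending it."""
    return json.dumps({"inputs": payload, "parameters": parameters, "options": options})


body = build_body(prompt, parameters, options)
print(body)
```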

This works: the model usually completes the prompt with the string "Positive".
My attempt in R:

library(httr)
headers <- c(
  `Authorization` = "TOKEN",
  `Content-Type` = "application/x-www-form-urlencoded"
)

# this was guessed manually from https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api

# does not quite work yet?
# removing the \" escapes does not make a difference
prompt <- "Tweet: I hate it when my phone battery dies.
Sentiment: Negative
###
Tweet: My day has been 👍
Sentiment: Positive
###
Tweet: This is the link to the article
Sentiment: Neutral
###
Tweet: This new music video was incredible
Sentiment:"

data <- c(`inputs` = prompt,
          `max_new_tokens` = 3, 
          `temperature` = 0.5,
          `end_sequence` = "###")

res <- httr::POST(url = "https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B", 
                  httr::add_headers(.headers=headers), body = data)

content(res)[[1]][[1]]
[1] "Tweet: I hate it when my phone battery dies.\nSentiment: Negative\n###\nTweet: My day has been 👍\nSentiment: Positive\n###\nTweet: This is the link to the article\nSentiment: Neutral\n###\nTweet: This new music video was incredible\nSentiment:\n3\n0.5\n###\n"

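The R output ends with "\n3\n0.5\n###\n": seemingly random newlines and numbers. I suspect the problem is either encoding or the way httr parses the JSON returned by the API.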
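The tail "\n3\n0.5\n###" is most likely not a parsing problem but the request itself. Assumption: as far as I know, httr's `POST()` sends a character-vector `body` as raw text with the elements pasted together by `"\n"`, and `c()` has already coerced `3` and `0.5` to the strings `"3"` and `"0.5"`. The sketch below simulates that in Python (the variable names are illustrative) and reproduces the echoed text; judging by the output, the API then treated the whole body as the model input and echoed it back.

```python
# Simulate what the R call appears to transmit (assumed httr behaviour:
# character vector body -> elements joined with "\n", sent as raw text).
prompt = "Tweet: This new music video was incredible\nSentiment:"

r_data = [prompt, "3", "0.5", "###"]   # contents of c(inputs = prompt, ...) after coercion
sent_text = "\n".join(r_data)          # the raw request body

# The parameter values end up as part of the input text, which the API echoes back.
print(repr(sent_text))
```

If that is the cause, a likely fix within httr is to pass a nested list, e.g. `body = list(inputs = prompt, parameters = list(max_new_tokens = 3, temperature = 0.5, end_sequence = "###"))`, together with `encode = "json"` and an `Authorization = "Bearer <token>"` header instead of the form-urlencoded `Content-Type`.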

Answer by abithluo

Here is a suggestion using the httr2 package, plus some data cleaning:

library(tidyverse)
library(httr2)

token = "TOKEN" 

prompt = "Tweet: I hate it when my phone battery dies.
Sentiment: Negative
###
Tweet: My day has been 👍
Sentiment: Positive
###
Tweet: This is the link to the article
Sentiment: Neutral
###
Tweet: This new music video was incredible
Sentiment:"

content <- "https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B" %>% 
  request() %>% 
  req_auth_bearer_token(token) %>% 
  req_body_json(list(
    inputs = prompt,
    # generation settings are nested under `parameters`, as in the Python query()
    parameters = list(
      max_new_tokens = 3,
      temperature = 1,
      end_sequence = "###"
    )
  )) %>%  
  req_perform() %>% 
  resp_body_json(simplifyVector = TRUE)
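As both outputs in this thread show, `generated_text` starts with the full prompt, so the model's actual answer is only the part after it. A minimal Python sketch of extracting it, assuming the prompt is echoed verbatim (the function name and the sample response are hypothetical). Alternatively, the text-generation task of the Inference API documents a `return_full_text` parameter that can be set to false; check the current docs before relying on it.

```python
def completion_only(generated_text: str, prompt: str) -> str:
    """Return only the newly generated part of generated_text.

    Assumes the API echoed the prompt verbatim at the start of the text;
    if it did not, the text is returned unchanged.
    """
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):]
    return generated_text


prompt = "Tweet: This new music video was incredible\nSentiment:"
generated = prompt + " Positive\n###"  # hypothetical response text
print(repr(completion_only(generated, prompt)))
```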

Data cleaning

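content %>%  
  # one row per line of generated text
  separate_rows(generated_text, sep = "\n") %>% 
  # drop the "###" separators between examples
  filter(generated_text != "###") %>%  
  # alternate lines: Col1 = tweet, Col2 = sentiment
  group_by(group = str_c("Col", rep(1:2, length.out = n()))) %>% 
  pivot_wider(names_from = group, values_from = generated_text) %>% 
  unnest(everything())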

# A tibble: 4 x 2
  Col1                                           Col2               
  <chr>                                          <chr>              
1 "Tweet: I hate it when my phone battery dies." Sentiment: Negative
2 "Tweet: My day has been \U0001f44d"            Sentiment: Positive
3 "Tweet: This is the link to the article"       Sentiment: Neutral 
4 "Tweet: This new music video was incredible"   Sentiment: Negative
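The same pairing logic, sketched in Python for comparison: split on newlines, drop the `###` separators, then take alternating lines as (tweet, sentiment) pairs. The function `pair_lines` is hypothetical, and the input text is reconstructed from the tibble above rather than taken from a real API response.

```python
from typing import List, Tuple


def pair_lines(generated_text: str) -> List[Tuple[str, str]]:
    """Pair consecutive lines of a few-shot completion into (tweet, sentiment) rows."""
    # Split into lines and drop the "###" separators (separate_rows() + filter()).
    lines = [line for line in generated_text.split("\n") if line != "###"]
    # Even-indexed lines are tweets, odd-indexed lines sentiments (the Col1/Col2 pivot).
    # zip() silently drops a trailing unpaired line.
    return list(zip(lines[0::2], lines[1::2]))


# generated_text reconstructed from the tibble shown above
text = (
    "Tweet: I hate it when my phone battery dies.\nSentiment: Negative\n###\n"
    "Tweet: My day has been 👍\nSentiment: Positive\n###\n"
    "Tweet: This is the link to the article\nSentiment: Neutral\n###\n"
    "Tweet: This new music video was incredible\nSentiment: Negative"
)

for tweet, sentiment in pair_lines(text):
    print(f"{tweet:<48} {sentiment}")
```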
