gpt-2 模型提示符是否支持多行输入?

5q4ezhmt  于 6个月前  发布在  其他
关注(0)|答案(5)|浏览(122)

目前我还没有找到输入多个段落或列表格式的方法。输入和所有其他换行符方法我已经尝试过不工作。

vqlkdk9b

vqlkdk9b1#

你可以通过引入自己的换行符并在输入中使用它来解决这个问题。模型通常会选择模式并使用符号本身。不过,我也希望能够使用换行符。

insrf1ej

insrf1ej2#

你可以在你的代码中使用\n。例如,你可以读入一个带有行结尾的文本文件,它就可以工作了。这可能需要稍微修改代码,但是如果我能做到,任何人都可以!

kzipqqlq

kzipqqlq3#

嘿,请我尝试做同样的事情,你可以请你告诉我如何修改代码,这样我就可以让它采取多个输入,提前感谢!我知道我是一个晚了这一点,但我目前的工作,我真的卡住了,到目前为止找不到任何帮助.

hyrbngr7

hyrbngr74#

这是我根据这个仓库中的示例编写的一些代码。只需将其放在与其他源代码相同的文件夹中。您可能需要更改模型的目录,因为我使用了一个我为我的目的进行了微调的个人目录。

import json
import os
import numpy as np
import tensorflow as tf

import model as model
import sample as sample
import encoder as encoder

class Generator():

    def __init__(self, sess, length=40, temperature=0.9, top_k=40):
    
        seed = None
        batch_size=1
        model_path='models/sanjeev-model-curated'
        self.sess = sess
    
        self.enc = encoder.get_encoder(model_path, '') # Note that the '' is to trick the encoder since we have the model name in the path
        hparams = model.default_hparams()
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))  

        self.context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        self.output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=self.context,
            batch_size=batch_size,
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(model_path)
        saver.restore(self.sess, ckpt)
            
        
    def generate(self, prompt):
        context_tokens = self.enc.encode(prompt)
        out = self.sess.run(self.output, feed_dict={
                self.context: [context_tokens for _ in range(1)]
            })[:, len(context_tokens):]

        text = self.enc.decode(out[0])
        return text

你只需要传入一个新的session(在上下文管理器中使用,因为急切执行的奇怪之处)。下面给出了一个简单的flask应用程序的例子。

from generator import Generator
import tensorflow as tf
import json
from getpass import getpass
import random
from flask import Flask, request, Response, make_response, jsonify
from waitress import serve
import regex

with tf.Session(graph=tf.Graph()) as sess:
    generator = Generator(sess)
    re = regex.compile('[a-zA-Z]')


    def PruneResult(text):
        text = text.split('\n')
        if re.search(text[0]):
            return text[0]
        else:
            for t in text[1:]:
                mess = t.split(':')
                mess = ':'.join(mess[1:])
                if re.search(mess):
                    return mess
        return 'Why are you so confusing, humans?'
            

    def onMessage(data):
            print('Raw Input: ' + str(data))
            new_data = {}
            prompt = ""
            for key in data.keys():
                new_data[int(key)] = data[key]
            data = new_data
            del new_data
            for key in sorted(data.keys()):
                prompt += f"{data[key][0]}: {data[key][1]}\n"
            prompt += "Damien: "
            print('Prompt: ' + prompt)
            result = generator.generate(prompt)
            print('Uncut result: ' + result)
            mess = PruneResult(result)
            print('Message: ' + mess)

            return mess



    print('Time to start')
    app = Flask(__name__)

    @app.route('/', methods=['POST'])
    def respond():
        try:
            authHeader = request.headers['Authorization']
        except:
            return make_response(jsonify({}), 400)
        if authHeader != "Bearer CHANGE_THIS": # Change this key for security, make sure to remember it to add it to your header
            return make_response(jsonify({}), 400)
        if request.remote_addr != "10.0.0.1" and request.remote_addr != "127.0.0.1": # Change these addresses to the addresses you want to accept requests from
            return make_response(jsonify({}), 400)
        if request.is_json:
            return make_response(jsonify(onMessage(request.json)), 201)
        else:
            print(request)
            return make_response(jsonify({}), 400)

    if __name__ == "__main__":
        serve(app, host='10.0.0.2', port=5000) # Change the host to 0.0.0.0 to make it public or 127.0.0.1 so it only works on your machine

curl -H "Authorization: Bearer CHANGE_THIS" -H "Content-Type: application/json" -X POST -d '{"1": ["James", "Hi Micheal, how are you?"], "2":["Micheal","Good thanks, just finished my project"],"3":["James","Thanks good, I hope you enjoyed it"]}' http://10.0.0.2:5000/是一个触发此操作的curl命令示例

oyt4ldly

oyt4ldly5#

我用多行标记法做了手脚:
prompt = ('"""' + "\n" + multiline_text + "\n" +'"""')

相关问题