gpt-2 interactive_conditional_samples不检查提示长度是否大于hparams.n_ctx / 2

mkshixfv  于 3个月前  发布在  其他
关注(0)|答案(8)|浏览(122)

如果长度更大,这将破坏模型,因为单词位置嵌入(wpeTensor)不够大。是否应该添加检查?

p1tboqfb

p1tboqfb1#

欢迎随时发送PR!

vlf7wbxs

vlf7wbxs2#

一直在忙碌学习强化学习,但肯定的是,当我抓住了休息!你看到我的拉请求#119?它只是不值得一个错误的风险?不完全确定,但从我的阅读/心理图,我认为它对模型的输出字面上为零的影响。

u0njafvf

u0njafvf3#

当我有时间的时候,我会测试它。它运行得很好,只是没有在相同的输入和种子上测试过。

rdlzhqv9

rdlzhqv94#

或者你可以考得更好,你有更多的资源,我想

inb24sb2

inb24sb25#

回顾!谢谢你的工作!
如果你能测试一下就太好了(我们这里很忙,忙碌)

hc8w905p

hc8w905p6#

当然,我会去的!我想你也没有比我更好的基础设施来测试

n9vozmp4

n9vozmp47#

它也可能崩溃。有一些错误,我得到的时候,输入是~460字(GPT-2)。阅读它可能与词汇,但我不知道如果这是这种情况在这里。
tensorflow.python.framework.errors_impl.InvalidArgumentError:indices[0,0] = 1024不在[0,1024)中
{{node sample_sequence_1/while/model/GatherV2_1}}

v64noz0r

v64noz0r8#

实际上,修复的方法不是计数编码的令牌并检查它是否没有超过长度变量吗?我们实际上是基于编码的令牌而不是输入的文本本身。我发现当我打印值时,它正好是512,它工作正常,>512崩溃。然而,我有更多修改的代码,所以我可以给一些提示:

context_length = 0
while not raw_text or context_length <= 0 or context_length > length:
[...] [inside loop:]
context_tokens = enc.encode(raw_text)
context_length = len(context_tokens)

这为我解决了问题,它不再崩溃!
我修改了原来的,pull请求:#142

相关问题