Python: How to use the M2M model in Hugging Face for translation into multiple target languages?

Asked by 4xrmg8kj on 2023-04-19 in Python

The M2M model was trained on about 100 languages and can translate between them, e.g.

from transformers import pipeline

m2m100 = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang="de")
m2m100(["hello world", "foo bar"])

[out]:

[{'translation_text': 'Hallo Welt'}, {'translation_text': 'Die Fu Bar'}]

But to translate into multiple target languages, the user has to initialize multiple pipelines:

from transformers import pipeline

m2m100_en_de = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang="de")

m2m100_en_fr = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang="fr")

print(m2m100_en_de(["hello world", "foo bar"]))
print(m2m100_en_fr(["hello world", "foo bar"]))

[out]:

[{'translation_text': 'Hallo Welt'}, {'translation_text': 'Die Fu Bar'}]
[{'translation_text': 'Bonjour Monde'}, {'translation_text': 'Le bar Fou'}]

With the M2M model, is there a way to handle multiple target and/or source languages with a single pipeline?

I tried this:

from transformers import pipeline

m2m100_en_defr = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang=["de", "fr"])

print(m2m100_en_defr(["hello world", "foo bar"]))

But it throws an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_28/3374873260.py in <module>
      3 m2m100_en_defr = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang=["de", "fr"])
      4 
----> 5 print(m2m100_en_defr(["hello world", "foo bar"]))

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/text2text_generation.py in __call__(self, *args, **kwargs)
    364               token ids of the translation.
    365         """
--> 366         return super().__call__(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/text2text_generation.py in __call__(self, *args, **kwargs)
    163         """
    164 
--> 165         result = super().__call__(*args, **kwargs)
    166         if (
    167             isinstance(args[0], list)

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/base.py in __call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1088                     inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
   1089                 )
-> 1090                 outputs = list(final_iterator)
   1091                 return outputs
   1092             else:

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/pt_utils.py in __next__(self)
    122 
    123         # We're out of items within a batch
--> 124         item = next(self.iterator)
    125         processed = self.infer(item, **self.params)
    126         # We now have a batch of "inferred things".

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/pt_utils.py in __next__(self)
    122 
    123         # We're out of items within a batch
--> 124         item = next(self.iterator)
    125         processed = self.infer(item, **self.params)
    126         # We now have a batch of "inferred things".

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    669     def _next_data(self):
    670         index = self._next_index()  # may raise StopIteration
--> 671         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672         if self._pin_memory:
    673             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/pt_utils.py in __getitem__(self, i)
     17     def __getitem__(self, i):
     18         item = self.dataset[i]
---> 19         processed = self.process(item, **self.params)
     20         return processed
     21 

/opt/conda/lib/python3.7/site-packages/transformers/pipelines/text2text_generation.py in preprocess(self, truncation, src_lang, tgt_lang, *args)
    313         if getattr(self.tokenizer, "_build_translation_inputs", None):
    314             return self.tokenizer._build_translation_inputs(
--> 315                 *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang
    316             )
    317         else:

/opt/conda/lib/python3.7/site-packages/transformers/models/m2m_100/tokenization_m2m_100.py in _build_translation_inputs(self, raw_inputs, src_lang, tgt_lang, **extra_kwargs)
    351         self.src_lang = src_lang
    352         inputs = self(raw_inputs, add_special_tokens=True, **extra_kwargs)
--> 353         tgt_lang_id = self.get_lang_id(tgt_lang)
    354         inputs["forced_bos_token_id"] = tgt_lang_id
    355         return inputs

/opt/conda/lib/python3.7/site-packages/transformers/models/m2m_100/tokenization_m2m_100.py in get_lang_id(self, lang)
    379 
    380     def get_lang_id(self, lang: str) -> int:
--> 381         lang_token = self.get_lang_token(lang)
    382         return self.lang_token_to_id[lang_token]
    383 

/opt/conda/lib/python3.7/site-packages/transformers/models/m2m_100/tokenization_m2m_100.py in get_lang_token(self, lang)
    376 
    377     def get_lang_token(self, lang: str) -> str:
--> 378         return self.lang_code_to_token[lang]
    379 
    380     def get_lang_id(self, lang: str) -> int:

TypeError: unhashable type: 'list'

The list fails because the M2M100 tokenizer's get_lang_token does a plain dict lookup, self.lang_code_to_token[lang], and a list is not hashable. One would expect the output to look something like this:

{"de": [{'translation_text': 'Hallo Welt'}, {'translation_text': 'Die Fu Bar'}]
 "fr": [{'translation_text': 'Bonjour Monde'}, {'translation_text': 'Le Foo Bar'}]
}

If we use multiple pipelines, is the model mmap'd and shared between them? Does it initialize multiple models with multiple tokenizers, or a single model with multiple tokenizers?

Answer 1 (kfgdxczn):

TL;DR (and recommendation):

src_lang and tgt_lang are __call__ arguments, so you can change the target language at call time on any pipeline:

print(m2m100_en_de(["hello world", "foo bar"], tgt_lang="fr"))
print(m2m100_en_fr(["hello world", "foo bar"]))

[out]:

[{'translation_text': 'Bonjour Monde'}, {'translation_text': 'Le bar Fou'}]
[{'translation_text': 'Bonjour Monde'}, {'translation_text': 'Le bar Fou'}]
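Building on this, here is a minimal sketch that keeps a single pipeline in memory and produces the dict-shaped output the question asked for, by overriding tgt_lang per call (the helper name translate_multi is my own, not part of the transformers API):

from transformers import pipeline

# One pipeline, loaded once; the tgt_lang set here is just a default.
m2m100 = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang='de')

def translate_multi(texts, tgt_langs):
    # One model and tokenizer in memory; only the target language changes per call.
    return {tgt: m2m100(texts, tgt_lang=tgt) for tgt in tgt_langs}

print(translate_multi(["hello world", "foo bar"], ["de", "fr"]))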

Detailed answer:

This loads the model twice:

from transformers import pipeline

m2m100_en_de = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang="de")

m2m100_en_fr = pipeline('translation', 'facebook/m2m100_418M', src_lang='en', tgt_lang="fr")
print(id(m2m100_en_de.model))
print(id(m2m100_en_fr.model))

[out]:

140447096860288
140447751061024

You can partially work around this by passing the model object directly instead of the model identifier string:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

t = AutoTokenizer.from_pretrained("facebook/m2m100_418M")

m = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")

m2m100_en_de = pipeline('translation', model=m, tokenizer=t, src_lang='en', tgt_lang="de")
m2m100_en_fr = pipeline('translation', model=m, tokenizer=t, src_lang='en', tgt_lang="fr")
print(m2m100_en_de(["hello world", "foo bar"]))
print(m2m100_en_fr(["hello world", "foo bar"]))
print(id(m2m100_en_de.model))
print(id(m2m100_en_fr.model))

[out]:

[{'translation_text': 'Hallo Welt'}, {'translation_text': 'Die Fu Bar'}]
[{'translation_text': 'Bonjour Monde'}, {'translation_text': 'Le bar Fou'}]
139674945076768
139674945076768

Note that I don't recommend this, since you can still end up with duplicated objects in memory; it is only meant to give you an idea and to show an approach that results in less duplication than your original attempt. You may also want to check this StackOverflow post: What is the difference between shallow copy, deepcopy and normal assignment operation?
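If you want to avoid the pipeline abstraction altogether, here is a minimal sketch that drives the tokenizer/model pair directly; it mirrors what the traceback shows the pipeline doing internally, i.e. turning tgt_lang into a forced_bos_token_id via the tokenizer's get_lang_id:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"
encoded = tokenizer(["hello world", "foo bar"], return_tensors="pt", padding=True)

# One model, one tokenizer; only the forced BOS (target-language) token changes.
for tgt in ["de", "fr"]:
    generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(tgt))
    print(tgt, tokenizer.batch_decode(generated, skip_special_tokens=True))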
