python-3.x 为什么在find_class中允许取消pickle,但却禁止取消pickle使用受限Unpickler的定制类?

zvms9eto  于 2023-03-09  发布在  Python
关注(0)|答案(1)|浏览(125)

我需要反复运行一些代码来训练一个模型,我发现在一次代码迭代后使用pickle来保存我的对象是很有用的,我可以在第二次迭代中加载并使用它。
但是由于pickle存在安全问题,我想使用restricted_loads选项,但是我似乎不能让它在自定义类中工作,下面是一个较小的代码块,在那里我得到了相同的错误:

import builtins
import io
import os
import pickle

safe_builtins = {
    'range',
    'complex',
    'set',
    'frozenset',
    'slice',
}

allow_classes = {
    '__main__.Shape'
}

class RestrictedUnpickler(pickle.Unpickler):

    def find_class(self, module, name):
        # Only allow safe classes from builtins.
        if module == "builtins" and name in safe_builtins | allow_classes:
            return getattr(builtins, name)
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                     (module, name))

def restricted_loads(s):
    """Helper function analogous to pickle.loads()."""
    return RestrictedUnpickler(io.BytesIO(s)).load()

class Person:
    def __init__(
        self,
        name: str,
        age: int,
    ):
        self.name = name
        self.age = age

class Shape:
    def __init__(
        self,
        name: Person,
        n: int = 50,
    ):
        self.person = Person(
            name = name,
            age = "10",
        )
        self.n = n
        
s = Shape(
    name = "name1",
    n = 30,
)

filepath = os.path.join(os.getcwd(), "temp.pkl")
with open(filepath, 'wb') as outp:
    pickle.dump(s, outp, -1)
    
with open(filepath, 'rb') as inp:
    x = restricted_loads(inp.read())

错误:

UnpicklingError                           Traceback (most recent call last)
Cell In[20], line 63
     60     pickle.dump(s, outp, -1)
     62 with open(filepath, 'rb') as inp:
---> 63     x = restricted_loads(inp.read())

Cell In[20], line 30, in restricted_loads(s)
     28 def restricted_loads(s):
     29     """Helper function analogous to pickle.loads()."""
---> 30     return RestrictedUnpickler(io.BytesIO(s)).load()

Cell In[20], line 25, in RestrictedUnpickler.find_class(self, module, name)
     23     return getattr(builtins, name)
     24 # Forbid everything else.
---> 25 raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
     26                              (module, name))

UnpicklingError: global '__main__.Shape' is forbidden
v8wbuo2f

v8wbuo2f1#

您只允许来自模块builtins的类。
但是__main__.Shape是模块__main__中名为Shape的类,而不是模块builtins中名为__main__.Shape的类。
所以一个显而易见的解决办法就是改变

if module == "builtins" and name in safe_builtins | allow_classes:
    return getattr(builtins, name)

if module == "builtins" and name in safe_builtins:
    return getattr(builtins, name)
elif module == "__main__" and name == "Shape":
    return Shape

相关问题