csv 将非表格、逗号分隔的数据转换为pydantic

dffbzjpn  于 2023-03-27  发布在  其他
关注(0)|答案(1)|浏览(167)

我有一个特殊的csv文件,格式如下:

A;ItemText;1;2
B;1;1.23,99
B;2;9.52,100
C;false

我想把这些数据转换成pydantic模型。
目前我子类化了FieldInfo类:

class CSVFieldInfo(FieldInfo):
    
    def __init__(self, **kwargs: Any):

        self.position = kwargs["position"]
        
        if not isinstance(self.position, int):
            raise ValueError("Position should be integer, got {}".format(type(self.position)))

        super().__init__()

def CSVField(position: int):
    return CSVFieldInfo(position=position)

此外,我还对BaseModel进行了子类化:

class CSVBaseModel(BaseModel):
    
    @classmethod
    def from_string(cls, string: str, sep: str=";"):
        
        # no double definitions
        l = [x.field_info.position for x in cls.__fields__.values()]
        if not len(set(l)) == len(l):
            raise ValueError("At least one position is defined twice")
        
        # here i am stuck on how to populate the model correctly (including nested models)

模型布局如下所示:

class CSVTypeA(CSVBaseModel):
    record_type: Literal["A"] = CSVField(position=0)
    record_text: str = CSVField(position=1)
    num: int = CSVField(position=2)

class CSVFile(CSVBaseModel):
    a: CSVTypeA

csv_string = \
"""A;ItemText;1;2
B;1;1.23,99
B;2;9.52,100
C;false"""

CSVFile.from_string(csv_string)

如何填充pydantic模型“CSVFile”,自动将正确的CSV-Line分配给正确的模型(通过字段“record_type”进行区分)?

avkwfej4

avkwfej41#

主要的问题是,如果不查看记录中的第一个字段,就无法知道记录与哪种类型兼容。
由于没有键-值对,相反,记录基本上只是没有名称的字段列表,因此我们必须在检查之后确定正确的字段名称,以及我们正在处理的记录类型。
这意味着我们需要重新实现Pydantic的区分联合背后的一些魔力。幸运的是,我们可以利用ModelField在其sub_fields_mapping属性中保存 discriminator key -〉sub-field 的字典的事实。
因此,我们仍然可以利用Pydantic提供的一些内置机制,并适当地定义我们的歧视联合。
但首先我们需要定义一些(示例性的)记录类型:

record_types.py

from typing import Literal, Union

from pydantic import BaseModel

class CSVLine(BaseModel):
    record_type: str

    def some_method(self) -> None:
        print(self)

class CSVTypeA(CSVLine):
    record_type: Literal["A"]
    record_text: str
    num: int
    another_num: int

class CSVTypeB(CSVLine):
    record_type: Literal["B"]
    num_foo: int
    num_floaty: float
    num_bar: int

class CSVTypeC(CSVLine):
    record_type: Literal["C"]
    spam: bool

CSVType = Union[CSVTypeA, CSVTypeB, CSVTypeC]

接下来,我们定义模型,将实际的CSV文件表示为一个自定义根类型,它将是一个list的记录类型的区分联合。
我们还将定义自己的类属性__csv_separator__来保存我们用来分割记录的字符串。
为了更容易和更直观地处理这个模型的示例,我们还将定义/覆盖一些用于项访问、字符串表示等的自定义方法。
最后,我们需要实现将CSV文件的各行解析为自定义validator中适当记录类型的示例的整个魔术。

csv_model.py

from collections.abc import Iterator
from typing import Annotated, Any, ClassVar

from pydantic import BaseModel, Field, validator
from pydantic.fields import ModelField

from .record_types import CSVType

class CSVFile(BaseModel):
    __csv_separator__: ClassVar[str] = ";"
    __root__: list[Annotated[CSVType, Field(discriminator="record_type")]]

    def __iter__(self) -> Iterator[CSVType]:  # type: ignore[override]
        yield from self.__root__

    def __getitem__(self, item: int) -> CSVType:
        return self.__root__[item]

    def __str__(self) -> str:
        return str(self.__root__)

    def __repr__(self) -> str:
        return repr(self.__root__)

    @validator("__root__", pre=True, each_item=True)
    def dict_from_string(cls, v: Any, field: ModelField) -> Any:
        if not isinstance(v, str):
            return v  # let default Pydantic validation take over
        record_fields = v.strip().split(cls.__csv_separator__)
        discriminator_key = record_fields[0]
        assert field.sub_fields_mapping is not None
        try:  # Determine the model to validate against
            type_ = field.sub_fields_mapping[discriminator_key].type_
        except KeyError:
            raise ValueError(f"{discriminator_key} is not a valid key")
        assert issubclass(type_, BaseModel)
        field_names = type_.__fields__.keys()
        return dict(zip(field_names, record_fields))

这应该是我们所需要的。
要创建CSVFile的示例,我们只需要任何可迭代的字符串(CSV文件中的行)。与所有自定义根类型一样,我们可以通过使用__root__关键字参数调用__init__方法或将可迭代的字符串传递给parse_obj方法来初始化它。

演示

csv_string = """
A;ItemText;1;2
B;1;1.23;99
B;2;9.52;100
C;false
""".strip()

obj = CSVFile.parse_obj(csv_string.split("\n"))
print(obj[0])
obj[3].some_method()
print(obj.json(indent=4))

输出:
一个一个三个一个一个一个一个一个四个一个一个一个一个一个五个一个
旁注:我们需要类变量__csv_separator__的原因是验证器是一个类方法,它需要知道要使用的分隔符。(就像你在最初的帖子中尝试的那样)并将分隔符作为参数传递,然后临时修改类变量并调用parse_obj,但我认为在子类中全局或选择性地更改分隔符可能更容易。
另外,我认为没有理由显式指定位置,只要记录类型模型的字段(它们的定义顺序)与CSV记录的实际字段匹配。

相关问题