c++ 如何将torch路径添加到Pybind11扩展?

xyhw6mcr  于 2023-06-25  发布在  其他
关注(0)|答案(1)|浏览(111)

我正在用C++编写一个python扩展,使用pybind11,我正在尝试创建setup.py文件,这是我目前为止所做的:

from glob import glob

import setuptools

from pybind11.setup_helpers import Pybind11Extension, build_ext

ext_modules = [
    Pybind11Extension(
        "my_ext",
        sorted(glob("src/my_ext/**/*.cc")),
    )
]

setuptools.setup(
    name="my_ext",
    version="0.1",
    package_dir={"": "src/my_ext/python/"},
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)

但是,当我运行pip install .时,我得到了这个错误:

In file included from src/my_ext/cc/thing.cc:7:
      src/my_ext/cc/thing.h:9:10: fatal error: 'torch/torch.h' file not found
      #include <torch/torch.h>
               ^~~~~~~~~~~~~~~
      1 error generated.
      error: command '/usr/bin/clang' failed with exit code 1

有没有什么参数可以传递给Pybind11Extension,让它找到torch并成功构建?

dwthyt8l

dwthyt8l1#

我自己想出来的。只需要使用一些pytorch的helper函数,并查看torch.utils.cpp_extension.CppExtension以获得指导:

import os
from glob import glob

import setuptools
from pybind11.setup_helpers import Pybind11Extension, build_ext
from torch.utils.cpp_extension import include_paths, library_paths

include_dirs = include_paths()
torch_library_paths = library_paths()
libraries = ["c10", "torch", "torch_cpu", "torch_python"]

ext_modules = [
    Pybind11Extension(
        "my_ext",
        sorted(glob("src/my_ext/**/*.cc")),
        include_dirs=include_dirs,
        libraries=libraries,
        library_dirs=torch_library_paths,
    )
]

setuptools.setup(
    name="my_ext",
    version="0.1",
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)

相关问题