Rust: lifetime error when implementing AsyncWrite for hyper's Sender

zqdjd7g9  posted on 2023-06-23  in Other

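I need to convert a hyper::body::Sender into a tokio::io::AsyncWrite and pass it to one of my reusable functions. That function is platform-agnostic and can be used for any I/O operation, which is why it takes an AsyncWrite as its parameter.
First I tried the stream-body crate, but found that it depends on an old version of tokio. So I decided to implement AsyncWrite for Sender myself. Then I got a lifetime error when storing the future in the struct.
Here is my attempt (Playground):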

use hyper::{Request, Body, body::Sender, Response}; // 0.14.26
use futures::{future::BoxFuture, Future}; // 0.3.28
use std::task::Poll;
use pin_project::pin_project; // 1.1.0
use tokio::io::AsyncWrite; // 1.28.2
use bytes::Bytes; // 1.4.0

#[pin_project]
pub struct SenderWriter {
    sender: Sender,
    #[pin]
    write_fut: Option<BoxFuture<'static, hyper::Result<()>>>,
    last_len: usize
}

impl SenderWriter {
    pub fn new(sender: Sender) -> SenderWriter {
        SenderWriter { sender, write_fut: None, last_len: 0 }
    }
}

impl AsyncWrite for SenderWriter {
    fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, std::io::Error>> {
        let mut this = self.project();
        
        if this.write_fut.is_none() {
            // Storing the last buffer length in memory
            *this.last_len = buf.len();
            // Creating the future
            let fut = this.sender.send_data(Bytes::copy_from_slice(buf));
            *this.write_fut = Some(Box::pin(fut));
        }

        // Keeping length in memory to send with poll result
        let last_len = this.last_len.clone();

        let polled = this.write_fut.as_mut().as_pin_mut().unwrap().poll(cx);

        if polled.is_ready() {
            // Resetting to accept another set of bytes
            *this.last_len = 0;
            *this.write_fut = None;
        }

        polled.map(move |res|res.map(|_|last_len).map_err(|e|std::io::Error::new(std::io::ErrorKind::Other, e)))
    }

    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        let res = this.sender.poll_ready(cx);
        res.map(|r|r.map_err(|e|std::io::Error::new(std::io::ErrorKind::Other, e)))
    }

    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
        self.poll_flush(cx)
    }
}

pub async fn my_reusable_fn<W: AsyncWrite+ Send + Unpin + 'static>(_writer: W) {
    
}

pub async fn download_handler(_req: Request<Body>) -> Response<Body> {
    let (sender, body) = Body::channel();
    let sender_writer = SenderWriter::new(sender);
    tokio::spawn(my_reusable_fn(sender_writer));
    Response::builder().body(body).unwrap()
}

Then I changed the 'static lifetime parameter on BoxFuture to a generic lifetime parameter, but the self.project() statement then produced a lifetime error.
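For context, here is a minimal std-only sketch of why the boxed future cannot be 'static. hyper 0.14 declares send_data as an async fn taking &mut self, so the future it returns borrows the sender; storing that future next to the sender would make the struct borrow from itself. MockSender, BorrowedFuture, boxed_send, and poll_once are hypothetical names used only for illustration, not hyper APIs.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

/// Hypothetical stand-in for hyper's `Sender`; it only records chunks.
pub struct MockSender {
    pub sent: Vec<Vec<u8>>,
}

impl MockSender {
    /// Same shape as `Sender::send_data`: an `async fn` taking `&mut self`,
    /// so the returned future holds on to the `&mut self` borrow.
    pub async fn send_data(&mut self, chunk: Vec<u8>) -> Result<(), String> {
        self.sent.push(chunk);
        Ok(())
    }
}

/// A boxed future tied to the lifetime of the borrow it captures.
pub type BorrowedFuture<'a> = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'a>>;

/// Boxing works only with the borrow's lifetime `'a`. Changing the return
/// type to `BorrowedFuture<'static>` is rejected by the compiler, because
/// the future cannot outlive `sender`.
pub fn boxed_send<'a>(sender: &'a mut MockSender, chunk: Vec<u8>) -> BorrowedFuture<'a> {
    Box::pin(sender.send_data(chunk))
}

/// Waker that does nothing; enough to poll a future by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls the borrowed future once and reports whether it completed.
pub fn poll_once(sender: &mut MockSender, chunk: Vec<u8>) -> bool {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = boxed_send(sender, chunk);
    let ready = fut.as_mut().poll(&mut cx).is_ready();
    ready
}
```

The future is fine as a local variable, as in poll_once, because it is dropped before the borrow of the sender ends. The problem only appears when it has to live in a field beside the value it borrows.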

n3h0vuf2


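All Sender::send_data() does is wait for the sender to become ready and then call try_send_data(). We can do that manually, which means no future has to be stored in the struct: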

use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use hyper::body::Sender;
use tokio::io::AsyncWrite;

pub struct SenderWriter {
    sender: Sender,
}

impl SenderWriter {
    pub fn new(sender: Sender) -> SenderWriter {
        SenderWriter { sender }
    }
}

impl AsyncWrite for SenderWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
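        // Wait until the body channel can accept another chunk. `?` converts
        // a hyper error into an io::Error and `ready!` propagates Pending.
        ready!(self
            .sender
            .poll_ready(cx)
            .map_err(|e| Error::new(ErrorKind::Other, e))?);

        // The channel is ready, so hand over a copy of the buffer without awaiting.
        match self.sender.try_send_data(Box::<[u8]>::from(buf).into()) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            // On failure try_send_data hands the chunk back; report a closed body.
            Err(_) => Poll::Ready(Err(Error::new(ErrorKind::Other, "Body closed"))),
        }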
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let res = self.sender.poll_ready(cx);
        res.map_err(|e| Error::new(ErrorKind::Other, e))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.poll_flush(cx)
    }
}
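To see the poll-then-try-send shape of the poll_write above in isolation, here is a rough std-only sketch that can be polled by hand. OneSlotSender is a hypothetical stand-in for hyper's Sender with a buffer of a single chunk (an assumption made for the demo, not hyper's real capacity), and drain and noop_waker are illustrative helpers.

```rust
use std::io::{Error, ErrorKind};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for `hyper::body::Sender` holding at most one chunk.
pub struct OneSlotSender {
    slot: Option<Vec<u8>>,
}

impl OneSlotSender {
    pub fn new() -> Self {
        OneSlotSender { slot: None }
    }

    /// Ready when the slot is empty. A real channel would also store
    /// `cx.waker()` so the task is woken once space frees up.
    pub fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.slot.is_some() {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Non-blocking send; gives the chunk back if the slot is occupied.
    pub fn try_send_data(&mut self, chunk: Vec<u8>) -> Result<(), Vec<u8>> {
        if self.slot.is_some() {
            return Err(chunk);
        }
        self.slot = Some(chunk);
        Ok(())
    }

    /// Simulates the receiving side taking the buffered chunk.
    pub fn drain(&mut self) -> Option<Vec<u8>> {
        self.slot.take()
    }
}

/// Readiness check first, then a non-blocking send. No future is created,
/// so nothing has to be stored between polls.
pub fn poll_write(
    sender: &mut OneSlotSender,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<Result<usize, Error>> {
    match sender.poll_ready(cx) {
        Poll::Pending => return Poll::Pending,
        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
        Poll::Ready(Ok(())) => {}
    }
    match sender.try_send_data(buf.to_vec()) {
        Ok(()) => Poll::Ready(Ok(buf.len())),
        Err(_) => Poll::Ready(Err(Error::new(ErrorKind::Other, "body closed"))),
    }
}

/// Waker that does nothing; enough to call poll functions by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

pub fn noop_waker() -> Waker {
    Waker::from(Arc::new(NoopWake))
}
```

With a context built from noop_waker(), a first poll_write call is accepted, a second one returns Poll::Pending while the slot is full, and it succeeds again after drain(). Because a Pending result never consumes the buffer, the caller can simply retry with the same bytes, which is what AsyncWrite expects.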
