Rust中的共享内存

qmb5sa22  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(92)

工作环境:
macOS索诺马Ver.14.0(M1 mac)Rust Ver.1.65.0
我想做的是:我想在多线程之间共享一个vec和一个[u8;128]元素的数组。共享时我想执行的要求如下。
1.整个vec必须可读
1.为了能够在vec中重写特定[u8; 128]类型的元素,
1.能够将[u8; 128]类型的数据插入vec
下面是我写的代码,但是这段代码最多可以做到阅读,但是有一个问题就是写的没有体现出来,如果我运行这段代码,然后在执行它的计算机上运行一次下面的命令

nc -v localhost 50051

个字符
将被输出。到目前为止这是正确的,但是第二次运行时的数据输出与第一次运行时相同。我的意图是第二个元素将输出具有3个填充的数据,如下所示,因为我在第一次运行中更新数据。

[[0u8; 128],[3u8; 128],[2u8; 128]]


我猜我对Arc的使用是错误的,它实际上是SharedData的克隆,而不是SharedData的引用,但我不知道如何识别这一点。我如何修复代码使其按预期工作?
main.rs:

use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use tokio_task_pool::Pool;

struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>
}

impl SharedData {
    fn new(data: RwLock<Vec<[u8; 128]>>) -> Self {
        Self {
            data: Arc::new(data)
        }
    }

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);
        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }
}

fn socket_to_async_tcplistener(s: socket2::Socket) -> std::io::Result<tokio::net::TcpListener> {
    std::net::TcpListener::from(s).try_into()
}

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    let read_guard = db_arc.data.read().unwrap();
    println!("In process() read: {:?}", *read_guard);
    db_arc.update(1, [3u8; 128]);
}

async fn serve(_: usize, tcplistener_arc: Arc<tokio::net::TcpListener>, db_arc: Arc<SharedData>) {
    let task_pool_capacity = 10;

    let task_pool = Pool::bounded(task_pool_capacity)
        .with_spawn_timeout(Duration::from_secs(300))
        .with_run_timeout(Duration::from_secs(300));
    
    loop {
        let (stream, _) = tcplistener_arc.as_ref().accept().await.unwrap();
        let db_arc_clone = db_arc.clone();

        task_pool.spawn(async move {
            process(stream, db_arc_clone).await;
        }).await.unwrap();
    }
}

#[tokio::main]
async fn main() {
    let addr: std::net::SocketAddr = "0.0.0.0:50051".parse().unwrap();
    let soc2 = socket2::Socket::new(
        match addr {
            SocketAddr::V4(_) => socket2::Domain::IPV4,
            SocketAddr::V6(_) => socket2::Domain::IPV6,
        },
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP)
    ).unwrap();
    
    soc2.set_reuse_address(true).unwrap();
    soc2.set_reuse_port(true).unwrap();
    soc2.set_nonblocking(true).unwrap();
    soc2.bind(&addr.into()).unwrap();
    soc2.listen(8192).unwrap();

    let tcp_listener = Arc::new(socket_to_async_tcplistener(soc2).unwrap());

    let mut vec = vec![
        [0u8; 128],
        [1u8; 128],
        [2u8; 128],
    ];

    let share_db = Arc::new(SharedData::new(RwLock::new(vec)));
    let mut handlers = Vec::new();
    for i in 0..num_cpus::get() - 1 {
        let cloned_listener = Arc::clone(&tcp_listener);
        let db_arc = share_db.clone();

        let h = std::thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(serve(i, cloned_listener, db_arc));
        });
        handlers.push(h);
    }

    for h in handlers {
        h.join().unwrap();
    }
}


Cargo.toml:

[package]
name = "tokio-test"
version = "0.1.0"
edition = "2021"

[dependencies]
log = "0.4.20"
env_logger = "0.10.0"
tokio = { version = "1.34.0", features = ["full"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_yaml = "0.9.27"
serde_derive = "1.0.193"
mio = {version="0.8.9", features=["net", "os-poll", "os-ext"]}
num_cpus = "1.16.0"
socket2 = { version="0.5.5", features = ["all"]}
array-macro = "2.1.8"
tokio-task-pool = "0.1.5"
argparse = "0.2.2"

xqnpmsa8

xqnpmsa81#

我没有看完整的代码,但有一些错误。

fn update()

fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);
        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }

字符串
这不是使用RwLock的方式:

  • 如果要修改数据,请使用self.data.write(),而不是self.data.read();
  • 我不知道你打算用第二个RwLock做什么,但它是无用的。

相反,做一些类似于

fn update(&self, index: usize, update_data: [u8; 128]) {
        let write_guard_for_array = self.data.write().unwrap();
        write_guard_for_array[index] = update_data;
    }

fn process()

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    let read_guard = db_arc.data.read().unwrap();
    println!("In process() read: {:?}", *read_guard);
    db_arc.update(1, [3u8; 128]);
}


一般来说,你可能不应该直接访问db_arc.data。但除此之外,一旦你修复了函数update(),这将导致死锁:
1.您获取db_arc.data.read()。根据RwLock的定义,这意味着在读取锁被释放之前,没有人可以修改db_arc.data的内容。
1.读锁仅在作用域结束时释放。
1.在作用域结束之前,您调用update(),它将尝试获取data.write()。但它无法获取它,直到读取锁被释放。
你可能想要的东西沿着线:

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    {
      let read_guard = db_arc.data.read().unwrap();
      println!("In process() read: {:?}", *read_guard);
    } // End of scope, `read_guard` is released.
    db_arc.update(1, [3u8; 128]);
}

时雄+线程

你混合使用了线程和时雄。这在理论上是可行的,但有风险。两种选择都是有效的,但我建议选择其中之一。通常,如果你有很多I/O(例如网络请求或磁盘访问),选择时雄,或者如果你有很多CPU使用,选择线程。

eqzww0vc

eqzww0vc2#

fn update(&self, index: usize, update_data: [u8; 128]) {
    let read_guard_for_array = self.data.read().unwrap();
    let write_lock = RwLock::new((*read_guard_for_array)[index]);

字符串
这将创建数据的副本并将其 Package 在无用的RwLock中(无用的原因是该副本始终保存在单个线程中)。

let mut write_guard_for_item = write_lock.write().unwrap();
    *write_guard_for_item = update_data;
}


这会修改副本,然后在函数结束时立即丢弃副本。
相反,你需要锁定你已经拥有的RwLock

fn update(&self, index: usize, update_data: [u8; 128]) {
    let mut write_guard = self.data.write().unwrap();
    write_guard[index] = update_data;
}


请注意,没有办法只为特定的项目获得写锁,并为整个数组获得读锁:读锁和写锁必须与相同的数据相关。这意味着您还需要在更新之前释放读锁:

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    let read_guard = db_arc.data.read().unwrap();
    println!("In process() read: {:?}", *read_guard);
    drop (read_guard);
    db_arc.update(1, [3u8; 128]);
}

相关问题