rust 自定义分类容器

ulydmbyx  于 2022-12-29  发布在  其他
关注(0)|答案(1)|浏览(139)

我希望有一个点的容器,在运行时,按照到某个给定点的距离排序,在其他语言中,我可以为容器提供一个自定义的比较函数,然而,我知道这在rust中是不可能的。
请考虑以下代码问题:

/// distance between two points
fn distance(a: &(f32, f32), b: &(f32, f32)) -> f32 {
    ((a.0-b.0)*(a.0-b.0) + (a.1-b.1)*(a.1-b.1)).sqrt()
}

fn main() {
  let origin = (1, 1);                 // assume values are provided at runtime
  let mut container = BTreeSet::new(); // should be sorted by distance to origin 
  container.insert((1 ,9));
  container.insert((2 ,2));
  container.insert((1 ,5));
}

在插入之后,我希望容器被排序为[(2,2),(1,5),(1,9)],这个例子使用了BTreeSet,我并不坚持使用它,但是感觉它最接近我的需要。
我不想要一个Vec,我必须在每个insert()之后手动求助。
那么,如何连接distance()origincontainer,最好是不依赖第三方?

bejyjqdl

bejyjqdl1#

我不认为有一种好方法可以做到这一点,除非将原点与每个点沿着存储,以便您可以在Cmp实现中使用它。

use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
use std::collections::BTreeSet;

#[derive(Debug, Clone, Copy)]
struct Point2D {
    origin: (f32, f32),
    point: (f32, f32),
}

impl Point2D {
    fn length(self) -> f32 {
        let (x1, y1) = self.origin;
        let (x2, y2) = self.point;
        ((x1 - x2).powi(2) + (y1 - y2).powi(2)).sqrt()
    }
}

impl PartialEq for Point2D {
    fn eq(&self, rhs: &Self) -> bool {
        self.origin == rhs.origin && self.length() == rhs.length()
    }
}

impl Eq for Point2D {}

impl PartialOrd for Point2D {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        (self.origin == rhs.origin).then_some(self.cmp(rhs))
    }
}

impl Ord for Point2D {
    fn cmp(&self, rhs: &Self) -> Ordering {
        self.length().total_cmp(&rhs.length())
    }
}

fn main() {
    let origin = (1.0, 1.0); // assume values are provided at runtime
    let mut container = BTreeSet::new(); // should be sorted by distance to origin
    container.insert(Point2D {
        origin,
        point: (1.0, 9.0),
    });
    container.insert(Point2D {
        origin,
        point: (2.0, 2.0),
    });
    container.insert(Point2D {
        origin,
        point: (1.0, 5.0),
    });
    println!("{:?}", container.iter().map(|p| p.point).collect::<Vec<_>>());
    // [(2.0, 2.0), (1.0, 5.0), (1.0, 9.0)]
}

相关问题