Rust Polars:可以对列表(list)列做减法吗?

zengzsys  于 2023-02-19  发布在  其他
关注(0)|答案(2)|浏览(130)

我想用数据来运算这个表达式。
但它给出List(float64)类型没有该操作的错误。
我猜列表类型没有实现元素级操作。

(col("vec").last() - col("vec")).abs().sum()
vec
   ---------
   list[f64]
   ============================
0: [-0.000000, -1.11111, ..., ]
1: [-2.222222,  3.33333, ..., ]
...
n: [ 8.888888, -9.99999, ..., ]

那么如果我想用最后一行减去每一行,最好的方法是什么?
下面是我想做的:

0: sum(abs([ 8.888888, -9.99999, ..., ] - [-0.000000, -1.11111, ..., ]))
1: sum(abs([ 8.888888, -9.99999, ..., ] - [-2.222222,  3.33333, ..., ]))
...
n: sum(abs([ 8.888888, -9.99999, ..., ] - [ 8.888888, -9.99999, ..., ]))
xe55xuns

xe55xuns1#

我可以用一种取巧的办法得到答案:先把列表转换为结构体,再对该结构体执行 unnest,从而在一个形状为 (vec, len(vec)) 的浮点数 DataFrame 中得到等效的数据,然后在其上执行标准运算,最后再转换回列表:

>>> x = pl.DataFrame(data={'vec' : [[0., -1-1/9], [-2-2/9, 3+3/9], [8+8/9, -10.]]})
>>> (x
 .select(dummy=pl.col('vec').arr.to_struct())
 .unnest('dummy')
 .select(vec=pl.concat_list(abs(pl.all().last() - pl.all()))
)

shape: (3, 1)
┌────────────────────────┐
│ vec                    │
│ ---                    │
│ list[f64]              │
╞════════════════════════╡
│ [8.888889, 8.888889]   │
│ [11.111111, 13.333333] │
│ [0.0, 0.0]             │
└────────────────────────┘
sgtfey8w

sgtfey8w2#

此代码在 Rust v1.67 和 polars v0.27.2 上测试通过。
在 Cargo.toml 中添加以下依赖,并启用相应的功能(features):

[dependencies]
# Pinned to the polars release this example was tested against (v0.27.2);
# a wildcard would resolve to whatever release is newest at build time.
polars = { version = "0.27.2", features = [ "lazy", "lazy_regex", "list_eval" ] }
color-eyre = "*"

main 函数:

use color_eyre::Result;
use polars::prelude::*;
use std::error::Error;

/// Demonstrates "last row minus every row" on a `list[f64]` column.
///
/// The list column is split into one scalar column per list position, the
/// absolute difference to the last row is computed per position, and the
/// differences are gathered back into a list column that is then summed.
/// Both the initial and the final frame are printed to stdout.
fn main() -> Result<(), Box<dyn Error>> {
    // Sample data: four rows, three values each.
    let first = vec![0.0, -1.0 - 1.0 / 9.0, 1.77];
    let second = vec![-2.0 - 2.0 / 9.0, 3.0 + 3.0 / 9.0, 2.93];
    let third = vec![3.0 + 3.0 / 9.0, -4.0 - 1.0 / 9.0, 3.56];
    let fourth = vec![8.0 + 8.0 / 9.0, -10.0, 7.26];

    // Every inner series becomes one list value of the "vec" column.
    let inner: [Series; 4] = [
        Series::new("a", &first),
        Series::new("b", &second),
        Series::new("c", &third),
        Series::new("d", &fourth),
    ];
    let nested = Series::new("vec", &inner);

    let df: DataFrame = DataFrame::new(vec![nested])?;
    println!("df:\n{df}\n");

    // The number of list positions is taken from the first row.
    let width = first.len();

    let mut new_columns: Vec<String> = Vec::new();
    let mut lazyframe = df.lazy();

    for idx in 0..width {
        let element = format!("vec_{idx}");
        let difference = format!("sub_{idx}");

        // Pull list position `idx` out into its own intermediate column.
        let extracted = col("vec").arr().get(lit(idx as i64)).alias(&element);

        // |last row - current row| for that position.
        let distance = (col(&element).last() - col(&element))
            .apply(absolute_value, GetOutput::from_type(DataType::Float64))
            .alias(&difference);

        // Two separate steps: `distance` reads the column `extracted` creates.
        lazyframe = lazyframe
            .with_columns([extracted])
            .with_columns([distance]);

        // Remember the intermediate column names, in creation order.
        new_columns.push(element);
        new_columns.push(difference);
    }

    // Gather every `sub_*` column into one list column, then sum each list.
    lazyframe = lazyframe
        .select([
            all(),
            concat_lst([col("^sub_.*$")]).alias("Concat lists"),
        ])
        .with_columns([col("Concat lists").arr().sum().alias("Sum")]);

    // uncomment to discard intermediate columns
    // lazyframe = lazyframe.drop_columns(new_columns);

    let result = lazyframe.collect()?;
    println!("dataframe:\n{result}\n");

    Ok(())
}

absolute_value函数如下所示:

/// Returns the element-wise absolute value of a `Float64` series.
///
/// Null entries stay null. The result is always wrapped in `Some`.
///
/// # Errors
///
/// Returns the `PolarsError` produced by `Series::f64` when `str_val` does
/// not hold `f64` data, so a dtype mismatch is reported to the caller
/// instead of panicking.
fn absolute_value(str_val: Series) -> Result<Option<Series>, PolarsError> {
    let series: Series = str_val
        // Propagate the dtype mismatch rather than aborting the process.
        .f64()?
        .into_iter()
        .map(|opt_value: Option<f64>| opt_value.map(f64::abs))
        .collect::<Float64Chunked>()
        .into_series();

    Ok(Some(series))
}

初始数据框:

df:
shape: (4, 1)
┌─────────────────────────────┐
│ vec                         │
│ ---                         │
│ list[f64]                   │
╞═════════════════════════════╡
│ [0.0, -1.111111, 1.77]      │
│ [-2.222222, 3.333333, 2.93] │
│ [3.333333, -4.111111, 3.56] │
│ [8.888889, -10.0, 7.26]     │
└─────────────────────────────┘

最终结果是:

dataframe:
shape: (4, 9)
┌─────────────────────────────┬───────────┬───────────┬───────────┬─────┬───────┬───────┬──────────────────────────────┬───────────┐
│ vec                         ┆ vec_0     ┆ sub_0     ┆ vec_1     ┆ ... ┆ vec_2 ┆ sub_2 ┆ Concat lists                 ┆ Sum       │
│ ---                         ┆ ---       ┆ ---       ┆ ---       ┆     ┆ ---   ┆ ---   ┆ ---                          ┆ ---       │
│ list[f64]                   ┆ f64       ┆ f64       ┆ f64       ┆     ┆ f64   ┆ f64   ┆ list[f64]                    ┆ f64       │
╞═════════════════════════════╪═══════════╪═══════════╪═══════════╪═════╪═══════╪═══════╪══════════════════════════════╪═══════════╡
│ [0.0, -1.111111, 1.77]      ┆ 0.0       ┆ 8.888889  ┆ -1.111111 ┆ ... ┆ 1.77  ┆ 5.49  ┆ [8.888889, 8.888889, 5.49]   ┆ 23.267778 │
│ [-2.222222, 3.333333, 2.93] ┆ -2.222222 ┆ 11.111111 ┆ 3.333333  ┆ ... ┆ 2.93  ┆ 4.33  ┆ [11.111111, 13.333333, 4.33] ┆ 28.774444 │
│ [3.333333, -4.111111, 3.56] ┆ 3.333333  ┆ 5.555556  ┆ -4.111111 ┆ ... ┆ 3.56  ┆ 3.7   ┆ [5.555556, 5.888889, 3.7]    ┆ 15.144444 │
│ [8.888889, -10.0, 7.26]     ┆ 8.888889  ┆ 0.0       ┆ -10.0     ┆ ... ┆ 7.26  ┆ 0.0   ┆ [0.0, 0.0, 0.0]              ┆ 0.0       │
└─────────────────────────────┴───────────┴───────────┴───────────┴─────┴───────┴───────┴──────────────────────────────┴───────────┘

相关问题