如何在Rust中将函数应用于polars DataFrame的多个列

pxq42qpu  于 2023-06-06  发布在  其他
关注(0)|答案(3)|浏览(224)

我想在Rust中对polars DataFrame的若干列应用一个用户自定义函数,该函数接受多个输入(每个输入对应DataFrame中的一列)。我目前使用的模式如下,想知道这是否是最佳做法?

fn my_filter_func(col1: &Series, col2: &Series, col2: &Series) -> ReturnType {
    let it = (0..n).map(|i| {
        let col1 = match col.get(i) {
            AnyValue::UInt64(val) => val,
            _ => panic!("Wrong type of col1!"),
        };
        // similar for col2 and col3
        // apply user-defined function to col1, col2 and col3
    }
    // convert it to a collection of the required type
}
p8h8hvxi

p8h8hvxi1#

您可以将Series向下转换为要迭代的正确类型,然后使用rust迭代器应用逻辑。

/// Placeholder for arbitrary user logic combining two `f32` values.
///
/// The current stand-in implementation returns `a` unchanged.
fn my_black_box_function(a: f32, b: f32) -> f32 {
    // `b` is part of the two-argument shape but is not used by the stand-in.
    let _ = b;
    a
}

/// Applies `my_black_box_function` element-wise to two `Float32` columns.
///
/// The two columns are iterated in lockstep. A row where either input is
/// null produces a null in the output.
///
/// # Panics
///
/// Panics if either column's dtype is not `DataType::Float32`; the message
/// reports the dtypes that were actually found.
fn apply_multiples(col_a: &Series, col_b: &Series) -> Float32Chunked {
    match (col_a.dtype(), col_b.dtype()) {
        (DataType::Float32, DataType::Float32) => {
            // The dtype check in the match arm covers these downcasts.
            let a = col_a.f32().unwrap();
            let b = col_b.f32().unwrap();

            a.into_iter()
                .zip(b.into_iter())
                // `Option::zip` is `Some` only when both sides are `Some`.
                .map(|(opt_a, opt_b)| {
                    opt_a
                        .zip(opt_b)
                        .map(|(a, b)| my_black_box_function(a, b))
                })
                .collect()
        }
        (dtype_a, dtype_b) => panic!(
            "unexpected dtypes: {:?} and {:?} (expected Float32 for both)",
            dtype_a, dtype_b
        ),
    }
}

Lazy API

您不必离开惰性API就可以使用my_black_box_function。
我们可以收集我们想要在Struct数据类型中应用的列,然后在该Series上应用闭包。

fn apply_multiples(lf: LazyFrame) -> Result<DataFrame> {
    df![
        "a" => [1.0, 2.0, 3.0],
        "b" => [3.0, 5.1, 0.3]
    ]?
    .lazy()
    .select([concat_lst(["col_a", "col_b"]).map(
        |s| {
            let ca = s.struct_()?;

            let b = ca.field_by_name("col_a")?;
            let a = ca.field_by_name("col_b")?;
            let a = a.f32()?;
            let b = b.f32()?;

            let out: Float32Chunked = a
                .into_iter()
                .zip(b.into_iter())
                .map(|(opt_a, opt_b)| match (opt_a, opt_b) {
                    (Some(a), Some(b)) => Some(my_black_box_function(a, b)),
                    _ => None,
                })
                .collect();

            Ok(out.into_series())
        },
        GetOutput::from_type(DataType::Float32),
    )])
    .collect()
}
rryofs0p

rryofs0p2#

我发现对我有用的解决方案是map_multiple(据我理解,适用于没有groupby/agg的情况)或apply_multiple(据我理解,适用于有groupby/agg的情况)。或者,您也可以使用map_many或apply_many。见下文。

use polars::prelude::*;
use polars::df;

/// Builds a small sample frame, prints it, then groups by `names` and prints
/// the per-group sum of the `total_delta_sens()` expression.
fn main() {
    let frame = df![
        "names" => ["a", "b", "a"],
        "values" => [1, 2, 3],
        "values_nulls" => [Some(1), None, Some(3)],
        "new_vals" => [Some(1.0), None, Some(3.0)]
    ]
    .unwrap();

    println!("{:?}", frame);

    // Lazy pipeline: one output row per distinct value of `names`.
    let grouped = frame
        .lazy()
        .groupby([col("names")])
        .agg([total_delta_sens().sum()]);

    // `collect` returns a `Result`; it is printed as-is, so an error is
    // shown rather than unwrapped.
    let outcome = grouped.collect();
    println!("{:?}", outcome);
}

/// Builds an expression that adds the `values`, `values_nulls` and `new_vals`
/// columns element-wise as `Float64`, with nulls replaced by zero before adding.
pub fn total_delta_sens() -> Expr {
    /// Casts `series` to `Float64` and fills its nulls with zero.
    fn as_f64_zero_filled(series: &Series) -> Result<Series> {
        series
            .cast(&DataType::Float64)?
            .fill_null(FillNullStrategy::Zero)
    }

    /// Sums every input column after normalising it with `as_f64_zero_filled`.
    ///
    /// Cast, fill and add failures are propagated to the caller instead of
    /// panicking. Indexing `s[0]` still panics on an empty slice; presumably
    /// `map_multiple` passes one series per input expression — verify.
    fn sum_fa(s: &mut [Series]) -> Result<Series> {
        let mut total = as_f64_zero_filled(&s[0])?;
        for series in &s[1..] {
            total = total.add_to(&as_f64_zero_filled(series)?)?;
        }
        Ok(total)
    }

    let inputs = [col("values"), col("values_nulls"), col("new_vals")];
    let output = GetOutput::from_type(DataType::Float64);
    map_multiple(sum_fa, inputs, output)
}

这里total_delta_sens只是一个方便的包装(wrapper)函数。你也可以直接在.agg([])或.with_columns([])中使用:lit::<f64>(0.0).map_many(sum_fa, &[col("norm"), col("uniform")], o)
在sum_fa中,你可以像Richie已经提到的那样向下转换到ChunkedArray和.iter()甚至.par_iter(),希望这能有所帮助

0pizxfdo

0pizxfdo3#

对于Polars 0.30版本(version = "0.30"),请使用:

lazyframe
.with_columns([
    cols(col1, col2, ..., colN)
   .apply(|series| 
       some_function(series), 
       GetOutput::from_type(DataType::Float64)
   )
]);

Cargo.toml 文件:

[dependencies]
polars = { version = "0.30", features = [
    "lazy", # Lazy API
    "round_series", # round underlying float types of Series
] }

main()函数:

use std::error::Error;

use polars::{
    prelude::*,
    datatypes::DataType,
};

/// Rounds the two Float64 columns of a sample frame to two decimals through
/// the lazy API, prints the frame before and after, and asserts the rounded
/// values.
fn main() -> Result<(), Box<dyn Error>> {
    let dataframe01: DataFrame = df!(
        "column integers"  => &[1, 2, 3, 4, 5, 6],
        "column float64 A" => [23.654, 0.319, 10.0049, 89.01999, -3.41501, 52.0766],
        "column options"   => [Some(28), Some(300), None, Some(2), Some(-30), None],
        "column float64 B" => [23.6499, 0.399, 10.0061, 89.0105, -3.4331, 52.099999],
    )?;

    println!("dataframe01: {dataframe01}\n");

    // Only these columns are passed through the rounding closure.
    let float_columns: Vec<&str> = vec!["column float64 A", "column float64 B"];

    // Build the lazy query and materialise it in a single chain.
    let dataframe02: DataFrame = dataframe01
        .lazy()
        .with_columns([cols(float_columns).apply(
            |series| series.round(2).map(Some),
            GetOutput::from_type(DataType::Float64),
        )])
        .collect()?;

    println!("dataframe02: {dataframe02}\n");

    // Expected contents of the rounded columns.
    let expected_a: Series = Series::new(
        "column float64 A",
        &[23.65, 0.32, 10.00, 89.02, -3.42, 52.08],
    );
    let expected_b: Series = Series::new(
        "column float64 B",
        &[23.65, 0.4, 10.01, 89.01, -3.43, 52.1],
    );

    let actual_a = dataframe02.column("column float64 A")?;
    assert_eq!(actual_a, &expected_a);
    let actual_b = dataframe02.column("column float64 B")?;
    assert_eq!(actual_b, &expected_b);

    Ok(())
}

输出:

dataframe01: shape: (6, 4)
┌─────────────────┬──────────────────┬────────────────┬──────────────────┐
│ column integers ┆ column float64 A ┆ column options ┆ column float64 B │
│ ---             ┆ ---              ┆ ---            ┆ ---              │
│ i32             ┆ f64              ┆ i32            ┆ f64              │
╞═════════════════╪══════════════════╪════════════════╪══════════════════╡
│ 1               ┆ 23.654           ┆ 28             ┆ 23.6499          │
│ 2               ┆ 0.319            ┆ 300            ┆ 0.399            │
│ 3               ┆ 10.0049          ┆ null           ┆ 10.0061          │
│ 4               ┆ 89.01999         ┆ 2              ┆ 89.0105          │
│ 5               ┆ -3.41501         ┆ -30            ┆ -3.4331          │
│ 6               ┆ 52.0766          ┆ null           ┆ 52.099999        │
└─────────────────┴──────────────────┴────────────────┴──────────────────┘

dataframe02: shape: (6, 4)
┌─────────────────┬──────────────────┬────────────────┬──────────────────┐
│ column integers ┆ column float64 A ┆ column options ┆ column float64 B │
│ ---             ┆ ---              ┆ ---            ┆ ---              │
│ i32             ┆ f64              ┆ i32            ┆ f64              │
╞═════════════════╪══════════════════╪════════════════╪══════════════════╡
│ 1               ┆ 23.65            ┆ 28             ┆ 23.65            │
│ 2               ┆ 0.32             ┆ 300            ┆ 0.4              │
│ 3               ┆ 10.0             ┆ null           ┆ 10.01            │
│ 4               ┆ 89.02            ┆ 2              ┆ 89.01            │
│ 5               ┆ -3.42            ┆ -30            ┆ -3.43            │
│ 6               ┆ 52.08            ┆ null           ┆ 52.1             │
└─────────────────┴──────────────────┴────────────────┴──────────────────┘

相关问题