spark:实现defaultparamswritable/mlwritable的正确方法

goucqfw6  于 2021-05-27  发布在  Spark
关注(0)|答案(0)|浏览(204)

我有一个Spark变压器,它的工作 Package 周围的bucketizer和它所做的基本上是分裂列在5桶。
我希望能够:
在pipelinemodel.fit()中使用它
用某个文件中的管道序列化它。
我应该如何实现mlwritable接口(或defaultparamswritable)?
这是我的变压器:

public class BucketizerTransformer extends Transformer {
    private static final long serialVersionUID = 5589399640951989469L;
    private String column;

    BucketizerTransformer(String column) {
        this.column = column;
    }

    @Override
    public String uid() {
        return "CustomTransformer" + serialVersionUID;
    }

    @Override
    public Dataset<Row> transform(Dataset<?> df) {
        Double min = getMinDoubleValue(df);
        Double max = getMaxDoubleValue(df);
        double step = (max - min) / 4;
        double[] splits = {min, min + step, min + 2 * step, min + 3 * step, max};
        Bucketizer bucketizer = new Bucketizer()
                .setInputCol(column)
                .setOutputCol(column + "_bucket")
                .setSplits(splits);
        return bucketizer.transform(df);
    }

    public String getOutputColumn() {
        return column + "_vector";
    }

    public Double getMaxDoubleValue(Dataset<?> df) {
        return (Double) df.groupBy().max(column).collectAsList().get(0).get(0);
    }

    public Double getMinDoubleValue(Dataset<?> df) {
        return (Double) df.groupBy().min(column).collectAsList().get(0).get(0);
    }

    @Override
    public Transformer copy(ParamMap arg0) {
        return null;
    }

    @Override
    public StructType transformSchema(StructType structType) {
        structType = structType.add(column + "_bucket", DataTypes.DoubleType, true);
        return structType;
    }
}

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题