Apache Flink - filter as termination condition

jhkqcmku, posted on 2021-06-24 in Flink

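I defined a filter as the termination condition for k-means. When I run my application, it always computes only one iteration.
I think the problem is here: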

DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));

or in the filter function:

public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        if(tuple.f0.equals(tuple.f1)) {
            return true;
        }
        else {
            return false;
        }
    }
}

Best regards, Paul
My full code is here:

public void run() {   
    //load properties
    Properties pro = new Properties();
    FileSystem fs = null;
    try {
        pro.load(FlinkMain.class.getResourceAsStream("/config.properties"));
        fs = FileSystem.get(new URI(pro.getProperty("hdfs.namenode")),new org.apache.hadoop.conf.Configuration());
    } catch (Exception e) {
        e.printStackTrace();
    }

    int maxIteration = Integer.parseInt(pro.getProperty("maxiterations"));
    String outputPath = fs.getHomeDirectory()+pro.getProperty("flink.output");
    // set up execution environment
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // get input points
    DataSet<GeoTimeDataTupel> points = getPointDataSet(env);
    DataSet<GeoTimeDataCenter> centroids = null;
    try {
        centroids = getCentroidDataSet(env);
    } catch (Exception e1) {
        e1.printStackTrace();
    }
    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<GeoTimeDataCenter> loop = centroids.iterate(maxIteration);
    DataSet<GeoTimeDataCenter> newCentroids = points
        // compute closest centroid for each point
        .map(new SelectNearestCenter(this.getBenchmarkCounter())).withBroadcastSet(loop, "centroids")
        // count and sum point coordinates for each centroid
        .groupBy(0).reduceGroup(new CentroidAccumulator())
        // compute new centroids from point counts and coordinate sums
        .map(new CentroidAverager(this.getBenchmarkCounter()));
    // feed new centroids back into next iteration with termination condition
    DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));
    DataSet<Tuple2<Integer, GeoTimeDataTupel>> clusteredPoints = points
        // assign points to final clusters
        .map(new SelectNearestCenter(-1)).withBroadcastSet(finalCentroids, "centroids");
    // emit result
    clusteredPoints.writeAsCsv(outputPath+"/points", "\n", " ");
    finalCentroids.writeAsText(outputPath+"/centers");//print();
    // execute program
    try {
        env.execute("KMeans Flink");
    } catch (Exception e) {
        e.printStackTrace();
    }
}

Answer 1 (8nuwlpux)

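I think the problem is the filter function (modulo the code you haven't posted). Flink's termination criterion works as follows: the iteration terminates if the provided termination DataSet is empty. Otherwise, the next iteration starts, provided the maximum number of iterations has not been exceeded.
Flink's filter function keeps only those elements for which the FilterFunction returns true. With your MyFilter implementation, you therefore keep only the centroids that are identical before and after the iteration. This means you get an empty DataSet if all centroids have changed, and the iteration terminates. This is clearly the inverse of the actual termination criterion, which should be: continue with k-means as long as at least one centroid has changed.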
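To make the effect of the inverted filter concrete, here is a minimal plain-Java sketch that simulates only the termination rule described above; it does not use Flink. Everything in it is a hypothetical stand-in: `TerminationDemo`, `step`, and `iterationsRun` are illustrative names, integers play the role of centroids, a toy update moves each value one unit toward 3, and pairing old and new values by index replaces the join.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiPredicate;

// Plain-Java simulation (no Flink): models only the rule
// "stop when the termination set is empty or maxIterations is reached".
final class TerminationDemo {

    // Toy stand-in for one k-means step: every value moves one unit toward 3.
    static List<Integer> step(List<Integer> centroids) {
        List<Integer> next = new ArrayList<>();
        for (int c : centroids) {
            next.add(Math.min(c + 1, 3));
        }
        return next;
    }

    // Runs at most maxIterations steps and returns how many were executed.
    // keep decides which (new, old) pairs enter the termination set.
    static int iterationsRun(List<Integer> start, int maxIterations,
                             BiPredicate<Integer, Integer> keep) {
        List<Integer> old = start;
        int executed = 0;
        while (executed < maxIterations) {
            List<Integer> fresh = step(old);
            executed++;
            // Build the termination set from the pairs accepted by the filter.
            List<Integer> terminationSet = new ArrayList<>();
            for (int i = 0; i < fresh.size(); i++) {
                if (keep.test(fresh.get(i), old.get(i))) {
                    terminationSet.add(fresh.get(i));
                }
            }
            if (terminationSet.isEmpty()) {
                break; // empty termination set: the iteration ends
            }
            old = fresh;
        }
        return executed;
    }

    public static void main(String[] args) {
        List<Integer> start = Arrays.asList(0, 1);
        // Keeps unchanged values, like the MyFilter in the question.
        int keepEqual = iterationsRun(start, 10, (n, o) -> n.equals(o));
        // Keeps changed values, the criterion the answer recommends.
        int keepChanged = iterationsRun(start, 10, (n, o) -> !n.equals(o));
        System.out.println("keep equal:   " + keepEqual);
        System.out.println("keep changed: " + keepChanged);
    }
}
```

With the start values 0 and 1, the keep-equal variant stops after 1 iteration because both values change in the first step, which matches the behavior reported in the question. The keep-changed variant runs 4 iterations: three in which at least one value still moves, plus one that detects no change.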
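You can implement this criterion with a coGroup function that emits an element whenever it has no matching centroid in the preceding centroid DataSet. This is similar to a left outer join, except that you discard the non-null matches.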

public static void main(String[] args) throws Exception {
    // set up the execution environment
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Element> oldDS = env.fromElements(new Element(1, "test"), new Element(2, "test"), new Element(3, "foobar"));
    DataSet<Element> newDS = env.fromElements(new Element(1, "test"), new Element(3, "foobar"), new Element(4, "test"));

    DataSet<Element> filtered = newDS.coGroup(oldDS).where("*").equalTo("*").with(new FilterCoGroup());

    filtered.print();
}

public static class FilterCoGroup implements CoGroupFunction<Element, Element, Element> {

    @Override
    public void coGroup(
            Iterable<Element> newElements,
            Iterable<Element> oldElements,
            Collector<Element> collector) throws Exception {

        // Materialize the old elements so they can be scanned once per new element.
        List<Element> persistedElements = new ArrayList<Element>();

        for(Element element: oldElements) {
            persistedElements.add(element);
        }

        // Emit only the new elements that have no equal counterpart among the old ones.
        for(Element newElement: newElements) {
            boolean contained = false;

            for(Element oldElement: persistedElements) {
                if(newElement.equals(oldElement)){
                    contained = true;
                    break;
                }
            }

            if(!contained) {
                collector.collect(newElement);
            }
        }
    }
}

public static class Element implements Key {
    private int id;
    private String name;

    public Element(int id, String name) {
        this.id = id;
        this.name = name;
    }

    public Element() {
        this(-1, "");
    }

    @Override
    public int hashCode() {
        return 31 + 7 * name.hashCode() + 11 * id;
    }

    @Override
    public boolean equals(Object obj) {
        if(obj instanceof Element) {
            Element element = (Element) obj;

            return id == element.id && name.equals(element.name);
        } else {
            return false;
        }
    }

    @Override
    public int compareTo(Object o) {
        if(o instanceof Element) {
            Element element = (Element) o;

            if(id == element.id) {
                return name.compareTo(element.name);
            } else {
                return id - element.id;
            }
        } else {
            throw new RuntimeException("Comparing incompatible types.");
        }
    }

    @Override
    public void write(DataOutputView dataOutputView) throws IOException {
        dataOutputView.writeInt(id);
        dataOutputView.writeUTF(name);
    }

    @Override
    public void read(DataInputView dataInputView) throws IOException {
        id = dataInputView.readInt();
        name = dataInputView.readUTF();
    }

    @Override
    public String toString() {
        return "(" + id + "; " + name + ")";
    }
}
