Apache Flink - filter as termination condition

I have defined a filter as the termination condition for k-means. If I run the app, it always computes only one iteration.

I think the problem is here:

DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));

or maybe in the filter function:

public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    @Override
    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        // keep the pair only if the centroid is unchanged
        return tuple.f0.equals(tuple.f1);
    }
}

Best wishes, Paul

My complete code:

public void run() {   
    //load properties
    Properties pro = new Properties();
    FileSystem fs = null;
    try {
        pro.load(FlinkMain.class.getResourceAsStream("/config.properties"));
        fs = FileSystem.get(new URI(pro.getProperty("hdfs.namenode")),new org.apache.hadoop.conf.Configuration());
    } catch (Exception e) {
        e.printStackTrace();
    }

    int maxIteration = Integer.parseInt(pro.getProperty("maxiterations"));
    String outputPath = fs.getHomeDirectory()+pro.getProperty("flink.output");
    // set up execution environment
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // get input points
    DataSet<GeoTimeDataTupel> points = getPointDataSet(env);
    DataSet<GeoTimeDataCenter> centroids = null;
    try {
        centroids = getCentroidDataSet(env);
    } catch (Exception e1) {
        e1.printStackTrace();
    }
    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<GeoTimeDataCenter> loop = centroids.iterate(maxIteration);
    DataSet<GeoTimeDataCenter> newCentroids = points
        // compute closest centroid for each point
        .map(new SelectNearestCenter(this.getBenchmarkCounter())).withBroadcastSet(loop, "centroids")
        // count and sum point coordinates for each centroid
        .groupBy(0).reduceGroup(new CentroidAccumulator())
        // compute new centroids from point counts and coordinate sums
        .map(new CentroidAverager(this.getBenchmarkCounter()));
    // feed new centroids back into next iteration with termination condition
    DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));
    DataSet<Tuple2<Integer, GeoTimeDataTupel>> clusteredPoints = points
        // assign points to final clusters
        .map(new SelectNearestCenter(-1)).withBroadcastSet(finalCentroids, "centroids");
    // emit result
    clusteredPoints.writeAsCsv(outputPath+"/points", "\n", " ");
    finalCentroids.writeAsText(outputPath+"/centers");//print();
    // execute program
    try {
        env.execute("KMeans Flink");
    } catch (Exception e) {
        e.printStackTrace();
    }
}

public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    @Override
    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        // keep the pair only if the centroid is unchanged
        return tuple.f0.equals(tuple.f1);
    }
}

1 answer


I think the problem is the filter function (modulo the code you haven't posted yet). Flink's termination criterion works as follows: the iteration terminates if the supplied termination-criterion DataSet is empty. Otherwise, the next iteration is started, provided the maximum number of iterations has not been reached.
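To make these semantics concrete, here is a minimal plain-Java model (not Flink code) of a bulk iteration with a termination criterion. The class BulkIterationModel, its method runBulkIteration, and the halving step are hypothetical illustrations; the sketch assumes, as described above, that the criterion is computed after each step and that an empty result ends the loop.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

// Plain-Java model of a bulk iteration with a termination criterion (not Flink code).
public class BulkIterationModel {

    // Runs `step` repeatedly. After each step, `criterion` is computed from the new and the
    // previous state; an empty criterion result ends the loop. Returns the number of
    // iterations executed (at most maxIterations).
    static <T> int runBulkIteration(List<T> initial,
                                    UnaryOperator<List<T>> step,
                                    BiFunction<List<T>, List<T>, List<T>> criterion,
                                    int maxIterations) {
        List<T> current = initial;
        for (int i = 1; i <= maxIterations; i++) {
            List<T> next = step.apply(current);
            boolean finished = criterion.apply(next, current).isEmpty();
            current = next;
            if (finished) {
                return i; // empty termination-criterion data set
            }
        }
        return maxIterations; // iteration limit reached
    }

    // Hypothetical step: integer-halve every value, so values eventually reach 0.
    static List<Integer> halveAll(List<Integer> xs) {
        List<Integer> out = new ArrayList<>();
        for (int x : xs) out.add(x / 2);
        return out;
    }

    // Hypothetical criterion: keep the values that changed in this step (compared by position).
    static List<Integer> changedValues(List<Integer> next, List<Integer> prev) {
        List<Integer> out = new ArrayList<>();
        for (int k = 0; k < next.size(); k++) {
            if (!next.get(k).equals(prev.get(k))) out.add(next.get(k));
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(runBulkIteration(List.of(8), BulkIterationModel::halveAll,
                BulkIterationModel::changedValues, 100));
        System.out.println(runBulkIteration(List.of(8), BulkIterationModel::halveAll,
                BulkIterationModel::changedValues, 2));
    }
}
```

Starting from 8, the value changes in each of the first four steps (8, 4, 2, 1, 0) and only the fifth step leaves it unchanged, so the first call returns 5; the second call hits the iteration limit and returns 2.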

Flink's filter keeps only the elements for which the FilterFunction returns true. So with your MyFilter implementation, you keep only the centroids that are identical before and after the iteration. This means you get an empty DataSet as soon as all centroids have changed, and the iteration terminates. Since all centroids usually move in the first iteration, your job stops after one iteration. This is the opposite of the intended termination criterion, which should be: continue with k-means as long as at least one centroid has changed.
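Note that the join on where("*").equalTo("*") already pairs up only centroids that are equal, so join plus MyFilter yields exactly the unchanged centroids. The following plain-Java sketch compares that criterion with the inverted one on a toy one-dimensional k-means. The points, the starting centroids, and the class TerminationCriterionDemo are hypothetical, and centroids are compared by exact value equality, mirroring a join on the whole element.

```java
import java.util.ArrayList;
import java.util.List;

public class TerminationCriterionDemo {

    // One k-means step on 1-D points: assign each point to its nearest centroid,
    // then move each centroid to the mean of its points (empty clusters keep their position).
    static List<Double> step(List<Double> points, List<Double> centroids) {
        double[] sums = new double[centroids.size()];
        int[] counts = new int[centroids.size()];
        for (double p : points) {
            int best = 0;
            for (int c = 1; c < centroids.size(); c++) {
                if (Math.abs(p - centroids.get(c)) < Math.abs(p - centroids.get(best))) best = c;
            }
            sums[best] += p;
            counts[best]++;
        }
        List<Double> next = new ArrayList<>();
        for (int c = 0; c < centroids.size(); c++) {
            next.add(counts[c] == 0 ? centroids.get(c) : sums[c] / counts[c]);
        }
        return next;
    }

    // Like join + MyFilter: keeps only the centroids that did NOT change.
    static List<Double> unchanged(List<Double> next, List<Double> prev) {
        List<Double> out = new ArrayList<>();
        for (Double c : next) if (prev.contains(c)) out.add(c);
        return out;
    }

    // Intended criterion: keeps the centroids that DID change.
    static List<Double> changed(List<Double> next, List<Double> prev) {
        List<Double> out = new ArrayList<>();
        for (Double c : next) if (!prev.contains(c)) out.add(c);
        return out;
    }

    // Runs until the criterion result is empty or maxIterations is reached;
    // returns the number of iterations executed.
    static int iterations(List<Double> points, List<Double> start, boolean keepChanged, int maxIterations) {
        List<Double> current = start;
        for (int i = 1; i <= maxIterations; i++) {
            List<Double> next = step(points, current);
            List<Double> criterion = keepChanged ? changed(next, current) : unchanged(next, current);
            current = next;
            if (criterion.isEmpty()) return i;
        }
        return maxIterations;
    }

    public static void main(String[] args) {
        List<Double> points = List.of(1.0, 2.0, 3.0, 10.0, 11.0, 12.0);
        List<Double> start = List.of(0.0, 5.0);
        System.out.println("keep unchanged (MyFilter): " + iterations(points, start, false, 20));
        System.out.println("keep changed:              " + iterations(points, start, true, 20));
    }
}
```

With these inputs, the MyFilter-style criterion stops after 1 iteration because both centroids move in the first step, while the "keep changed" criterion runs 3 iterations and stops once the centroids settle at 2.0 and 11.0.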



You can do this with a coGroup function that emits a new centroid only if there is no matching centroid in the previous centroid DataSet. It works like a left outer join in which you throw away the non-empty matches.
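In plain Java, this coGroup can be modeled by grouping both inputs on the whole element and emitting a group of new elements only when the old side has no group with that key. This is a sketch with a hypothetical AntiJoinModel class, using (id, name) pairs as Map.Entry values instead of Flink's Element type; it is not Flink's API. Because grouping on "*" puts only equal elements into a group, the nested loop in FilterCoGroup amounts to checking whether the old group is empty.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class AntiJoinModel {

    // Plain-Java model of newDS.coGroup(oldDS).where("*").equalTo("*"):
    // both inputs are grouped by the whole element, and for every key the
    // group of new elements is emitted only if there is no group of old elements.
    static <T> List<T> newOnly(List<T> newElements, List<T> oldElements) {
        Map<T, List<T>> newGroups = new LinkedHashMap<>();
        Map<T, List<T>> oldGroups = new LinkedHashMap<>();
        for (T e : newElements) newGroups.computeIfAbsent(e, k -> new ArrayList<>()).add(e);
        for (T e : oldElements) oldGroups.computeIfAbsent(e, k -> new ArrayList<>()).add(e);

        List<T> result = new ArrayList<>();
        for (Map.Entry<T, List<T>> group : newGroups.entrySet()) {
            // left outer join that drops every key with a match on the old side
            if (!oldGroups.containsKey(group.getKey())) result.addAll(group.getValue());
        }
        return result;
    }

    public static void main(String[] args) {
        // Same sample data as the Flink example, as (id, name) pairs.
        List<Map.Entry<Integer, String>> oldDS =
                List.of(Map.entry(1, "test"), Map.entry(2, "test"), Map.entry(3, "foobar"));
        List<Map.Entry<Integer, String>> newDS =
                List.of(Map.entry(1, "test"), Map.entry(3, "foobar"), Map.entry(4, "test"));
        System.out.println(newOnly(newDS, oldDS));
    }
}
```

On the same sample data as the Flink example that follows, only (4, "test") survives, since (1, "test") and (3, "foobar") have matches on the old side.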

public static void main(String[] args) throws Exception {
    // set up the execution environment
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Element> oldDS = env.fromElements(new Element(1, "test"), new Element(2, "test"), new Element(3, "foobar"));
    DataSet<Element> newDS = env.fromElements(new Element(1, "test"), new Element(3, "foobar"), new Element(4, "test"));

    DataSet<Element> filtered = newDS.coGroup(oldDS).where("*").equalTo("*").with(new FilterCoGroup());

    filtered.print();
}

public static class FilterCoGroup implements CoGroupFunction<Element, Element, Element> {

    @Override
    public void coGroup(
            Iterable<Element> newElements,
            Iterable<Element> oldElements,
            Collector<Element> collector) throws Exception {

        // materialize the old elements so they can be scanned once per new element
        List<Element> persistedElements = new ArrayList<Element>();

        for(Element element: oldElements) {
            persistedElements.add(element);
        }

        // emit a new element only if no equal element exists on the old side
        for(Element newElement: newElements) {
            boolean contained = false;

            for(Element oldElement: persistedElements) {
                if(newElement.equals(oldElement)){
                    contained = true;
                    break;
                }
            }

            if(!contained) {
                collector.collect(newElement);
            }
        }
    }
}

public static class Element implements Key {
    private int id;
    private String name;

    public Element(int id, String name) {
        this.id = id;
        this.name = name;
    }

    public Element() {
        this(-1, "");
    }

    @Override
    public int hashCode() {
        return 31 + 7 * name.hashCode() + 11 * id;
    }

    @Override
    public boolean equals(Object obj) {
        if(obj instanceof Element) {
            Element element = (Element) obj;

            return id == element.id && name.equals(element.name);
        } else {
            return false;
        }
    }

    @Override
    public int compareTo(Object o) {
        if(o instanceof Element) {
            Element element = (Element) o;


            if(id == element.id) {
                return name.compareTo(element.name);
            } else {
                return id - element.id;
            }
        } else {
            throw new RuntimeException("Comparing incompatible types.");
        }
    }

    @Override
    public void write(DataOutputView dataOutputView) throws IOException {
        dataOutputView.writeInt(id);
        dataOutputView.writeUTF(name);
    }

    @Override
    public void read(DataInputView dataInputView) throws IOException {
        id = dataInputView.readInt();
        name = dataInputView.readUTF();
    }

    @Override
    public String toString() {
        return "(" + id + "; " + name + ")";
    }
}
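Element's write and read methods must handle the fields in the same order. Since Flink's DataOutputView and DataInputView extend java.io.DataOutput and java.io.DataInput, the same writeInt/writeUTF and readInt/readUTF pattern can be checked with the JDK's stream classes. The sketch below is a stripped-down stand-in (hypothetical class ElementRoundTrip) that uses java.io instead of Flink's memory views.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

public class ElementRoundTrip {

    // Stand-in for Element: same fields and same write/read order,
    // but written against java.io.DataOutput/DataInput instead of Flink's views.
    static final class Element {
        int id;
        String name;

        Element(int id, String name) { this.id = id; this.name = name; }
        Element() { this(-1, ""); }

        void write(DataOutput out) throws IOException {
            out.writeInt(id);
            out.writeUTF(name);
        }

        void read(DataInput in) throws IOException {
            id = in.readInt();
            name = in.readUTF();
        }
    }

    // Serializes an element to bytes and reads it back into a fresh instance.
    static Element roundTrip(Element e) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                e.write(out);
            }
            Element copy = new Element();
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                copy.read(in);
            }
            return copy;
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    public static void main(String[] args) {
        Element copy = roundTrip(new Element(3, "foobar"));
        System.out.println(copy.id + " " + copy.name);
    }
}
```

If write and read ever disagree on field order, the round trip fails or yields a corrupted copy, which is the same failure Flink would hit when it deserializes the real Element.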
