package org.ua.ap;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.NeighborsFunctionWithVertexValue;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexCentricIteration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;

/**
 * Affinity propagation clustering expressed as a Gelly vertex-centric iteration.
 *
 * <p>The input graph carries the similarity {@code s(i,k)} as the value of edge
 * {@code i -> k}; a self-edge {@code i -> i} carries the preference {@code s(i,i)}.
 * During the computation every vertex {@code v} holds a map keyed by neighbour id
 * {@code k} whose value is the triple {@code (f0, f1, f2)}:
 * <ul>
 *   <li>{@code f0} - similarity between {@code v} and {@code k},</li>
 *   <li>{@code f1} - responsibility {@code r(v,k)},</li>
 *   <li>{@code f2} - availability that {@code v} computed for {@code k}
 *       (for {@code k == v}: the self-availability {@code a(v,v)}).</li>
 * </ul>
 *
 * <p>NOTE(review): messages are sent along outgoing edges while map entries are
 * created from incoming edges, so the code appears to assume that every edge
 * {@code i -> k} has a reverse edge {@code k -> i}; a lookup for a missing
 * neighbour entry fails with a {@code NullPointerException}. Confirm against callers.
 */
@SuppressWarnings({ "serial", "hiding" })
public class AffinityPropogation implements GraphAlgorithm<Long, Long, Double> {

    /**
     * Number of vertex-centric supersteps to run. One responsibility update and one
     * availability update take one superstep each, so a full message-passing round
     * consumes two supersteps.
     */
    private final int maxIterations;

    /** Damping factor: the weight given to the previous value when a message is updated. */
    private final double lambda;

    /**
     * @param maxIterations number of supersteps of the main vertex-centric iteration
     * @param lambda        damping factor applied to responsibility and availability updates
     */
    public AffinityPropogation(int maxIterations, double lambda) {
        this.maxIterations = maxIterations;
        this.lambda = lambda;
    }

    /**
     * Builds the working graph: every vertex value becomes a map
     * {@code k -> (s(v,k), r(v,k), a)} that is seeded with a zero entry for the vertex
     * itself and then filled by one superstep in which every vertex sends its edge values
     * to the edge targets. Self-edges are removed afterwards, once their value has been
     * recorded as the vertex's own similarity entry.
     *
     * @param originalGraph graph whose edge values are similarities
     * @return graph with map-valued vertices and without self-edges
     */
    private Graph<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double> transferGraph(
            Graph<Long, Long, Double> originalGraph) {

        DataSet<Vertex<Long, Long>> vertices = originalGraph.getVertices();
        DataSet<Edge<Long, Double>> edges = originalGraph.getEdges();

        // Replace each vertex value with a map that initially holds only the vertex's own entry.
        DataSet<Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>>> newVertices = vertices.map(
                new MapFunction<Vertex<Long, Long>,
                        Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>>>() {
                    @Override
                    public Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>> map(
                            Vertex<Long, Long> v) throws Exception {
                        HashMap<Long, Tuple3<Double, Double, Double>> hashmap =
                                new HashMap<Long, Tuple3<Double, Double, Double>>();
                        hashmap.put(v.getId(), new Tuple3<Double, Double, Double>(0.0, 0.0, 0.0));
                        return new Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>>(
                                v.getId(), hashmap);
                    }
                });

        Graph<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double> newGraph =
                Graph.fromDataSet(newVertices, edges, ExecutionEnvironment.getExecutionEnvironment());

        // A single superstep distributes the edge values into the vertex maps.
        Graph<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double> hashMappedGraph =
                newGraph.runVertexCentricIteration(new InitialVertexUpdater(), new InitialMessenger(), 1);

        // Drop self-edges. The ids are boxed Longs, so they must be compared with equals():
        // '!=' compares references and would keep self-edges.
        return hashMappedGraph.filterOnEdges(new FilterFunction<Edge<Long, Double>>() {
            @Override
            public boolean filter(Edge<Long, Double> edge) throws Exception {
                return !edge.f0.equals(edge.f1);
            }
        });
    }

    /**
     * Initialisation step: records the edge value received from each sender as the
     * similarity ({@code f0}) of the map entry keyed by the sender's id.
     */
    public static final class InitialVertexUpdater
            extends VertexUpdateFunction<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Tuple2<Long, Double>> {

        @Override
        public void updateVertex(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue,
                MessageIterator<Tuple2<Long, Double>> inMessages) throws Exception {
            for (Tuple2<Long, Double> message : inMessages) {
                // Look the sender up by key; containsValue() would compare the id with the tuples.
                Tuple3<Double, Double, Double> entry = vertexValue.get(message.f0);
                if (entry == null) {
                    vertexValue.put(message.f0,
                            new Tuple3<Double, Double, Double>(message.f1, 0.0, 0.0));
                } else {
                    entry.f0 = message.f1;
                }
            }
            setNewVertexValue(vertexValue);
        }
    }

    /** Initialisation step: sends {@code (own id, edge value)} to the target of every outgoing edge. */
    public static final class InitialMessenger
            extends MessagingFunction<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Tuple2<Long, Double>, Double> {

        @Override
        public void sendMessages(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue) throws Exception {
            for (Edge<Long, Double> e : getOutgoingEdges()) {
                sendMessageTo(e.getTarget(), new Tuple2<Long, Double>(vertexKey, e.getValue()));
            }
        }
    }

    /**
     * Runs affinity propagation on {@code input}.
     *
     * @param input graph whose edge values are similarities (self-edges carry preferences)
     * @return graph with the input's edges in which each vertex value is the id of the
     *         exemplar selected for that vertex
     * @throws Exception if building or running the Flink program fails
     */
    @Override
    public Graph<Long, Long, Double> run(Graph<Long, Long, Double> input) throws Exception {
        Graph<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double> graph = transferGraph(input);

        // Alternate responsibility and availability updates for maxIterations supersteps.
        Graph<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double> stableGraph =
                graph.runVertexCentricIteration(
                        new VertexUpdater(lambda), new InformationMessenger(), maxIterations);

        DataSet<Vertex<Long, Long>> resultSet =
                stableGraph.reduceOnNeighbors(new NeighborSelection(), EdgeDirection.OUT);

        return Graph.fromDataSet(resultSet, input.getEdges(),
                ExecutionEnvironment.getExecutionEnvironment());
    }

    /**
     * Chooses the exemplar of a vertex. A vertex whose own evidence
     * {@code r(v,v) + a(v,v)} is positive is its own exemplar; otherwise the vertex picks,
     * among the neighbours with positive own evidence, the one it is most similar to.
     * If no neighbour qualifies the vertex keeps itself.
     */
    public static final class NeighborSelection implements
            NeighborsFunctionWithVertexValue<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Double, Vertex<Long, Long>> {

        @Override
        public Vertex<Long, Long> iterateNeighbors(
                Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>> vertex,
                Iterable<Tuple2<Edge<Long, Double>, Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>>>> neighbors)
                throws Exception {
            HashMap<Long, Tuple3<Double, Double, Double>> hmap = vertex.getValue();
            Tuple3<Double, Double, Double> own = hmap.get(vertex.getId());
            double selfEvidence = own.f1 + own.f2;
            if (selfEvidence > 0) {
                return new Vertex<Long, Long>(vertex.getId(), vertex.getId());
            }

            double maxSimilarity = Double.NEGATIVE_INFINITY;
            Long belongExemplar = vertex.getId();
            for (Tuple2<Edge<Long, Double>, Vertex<Long, HashMap<Long, Tuple3<Double, Double, Double>>>> neighbor
                    : neighbors) {
                Long neighborId = neighbor.f1.getId();
                Tuple3<Double, Double, Double> neighborOwn = neighbor.f1.getValue().get(neighborId);
                double neighborEvidence = neighborOwn.f1 + neighborOwn.f2;
                if (neighborEvidence > 0) {
                    // Rank candidate exemplars by similarity (f0) and remember the best
                    // similarity seen so far, not the candidate's evidence.
                    double neighborSimilarity = hmap.get(neighborId).f0;
                    if (neighborSimilarity > maxSimilarity) {
                        belongExemplar = neighborId;
                        maxSimilarity = neighborSimilarity;
                    }
                }
            }
            return new Vertex<Long, Long>(vertex.getId(), belongExemplar);
        }
    }

    /**
     * Main-iteration update function. Odd supersteps consume availabilities and update the
     * responsibilities {@code r(v,k)}; even supersteps consume responsibilities and update
     * the availabilities. Every update is damped with {@code lambda}:
     * {@code new = (1 - lambda) * computed + lambda * old}.
     */
    public static final class VertexUpdater
            extends VertexUpdateFunction<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Tuple2<Long, Double>> {

        private final double lambda;

        /** @param lambda damping factor (weight of the previous value) */
        public VertexUpdater(Double lambda) {
            this.lambda = lambda;
        }

        @Override
        public void updateVertex(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue,
                MessageIterator<Tuple2<Long, Double>> inMessages) throws Exception {
            if (getSuperstepNumber() % 2 == 1) {
                updateResponsibilities(vertexKey, vertexValue, inMessages);
            } else {
                updateAvailabilities(vertexKey, vertexValue, inMessages);
            }
            setNewVertexValue(vertexValue);
        }

        /**
         * Odd superstep: each message is {@code (k, a(v,k))}. Sets
         * {@code r(v,k) = s(v,k) - max over k' != k of (a(v,k') + s(v,k'))} for every map entry.
         */
        private void updateResponsibilities(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue,
                MessageIterator<Tuple2<Long, Double>> inMessages) {
            Tuple3<Double, Double, Double> own = vertexValue.get(vertexKey);

            // Track the largest and second-largest a + s. The vertex's own term starts as the
            // maximum, so it is demoted to runner-up automatically when a neighbour beats it.
            double maxSum = own.f0 + own.f2;
            Long maxKey = vertexKey;
            double secondMaxSum = Double.NEGATIVE_INFINITY;
            for (Tuple2<Long, Double> msg : inMessages) {
                Long adjacentVertex = msg.f0;
                double sum = msg.f1 + vertexValue.get(adjacentVertex).f0;
                if (sum > maxSum) {
                    secondMaxSum = maxSum;
                    maxSum = sum;
                    maxKey = adjacentVertex;
                } else if (sum > secondMaxSum) {
                    secondMaxSum = sum;
                }
            }

            for (Entry<Long, Tuple3<Double, Double, Double>> entry : vertexValue.entrySet()) {
                Tuple3<Double, Double, Double> value = entry.getValue();
                // The arg-max candidate competes against the runner-up, all others against the
                // maximum. Ids are boxed, hence equals() rather than a reference comparison.
                double strongestRival = entry.getKey().equals(maxKey) ? secondMaxSum : maxSum;
                double newRespons = value.f0 - strongestRival;
                value.f1 = (1 - lambda) * newRespons + lambda * value.f1;
            }
        }

        /**
         * Even superstep: each message is {@code (i, r(i,v))}. Sets
         * {@code a(i,v) = min(0, r(v,v) + sum over i' != i of max(0, r(i',v)))} for every
         * sender {@code i}, and {@code a(v,v) = sum over i' of max(0, r(i',v))}.
         */
        private void updateAvailabilities(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue,
                MessageIterator<Tuple2<Long, Double>> inMessages) {
            // The message iterator can be traversed only once, so buffer what is needed.
            List<Long> senders = new ArrayList<Long>();
            List<Double> responsibilities = new ArrayList<Double>();
            double positiveSum = 0.0;
            for (Tuple2<Long, Double> msg : inMessages) {
                senders.add(msg.f0);
                responsibilities.add(msg.f1);
                positiveSum += Math.max(0.0, msg.f1);
            }

            Tuple3<Double, Double, Double> own = vertexValue.get(vertexKey);
            double selfRespons = own.f1;
            for (int i = 0; i < senders.size(); i++) {
                // Remove exactly the sender's own contribution to the positive sum.
                double withoutSender = positiveSum - Math.max(0.0, responsibilities.get(i));
                double newAvailability = Math.min(0.0, selfRespons + withoutSender);
                Tuple3<Double, Double, Double> value = vertexValue.get(senders.get(i));
                value.f2 = (1 - lambda) * newAvailability + lambda * value.f2;
            }

            own.f2 = (1 - lambda) * positiveSum + lambda * own.f2;
        }
    }

    /**
     * Main-iteration messaging function. In odd supersteps a vertex sends each out-neighbour
     * the availability it holds for that neighbour ({@code f2}); in even supersteps it sends
     * the responsibility ({@code f1}). Every message is tagged with the sender's id.
     */
    public static final class InformationMessenger
            extends MessagingFunction<Long, HashMap<Long, Tuple3<Double, Double, Double>>, Tuple2<Long, Double>, Double> {

        @Override
        public void sendMessages(Long vertexKey,
                HashMap<Long, Tuple3<Double, Double, Double>> vertexValue) throws Exception {
            boolean sendAvailability = getSuperstepNumber() % 2 == 1;
            for (Edge<Long, Double> e : getOutgoingEdges()) {
                Long dest = e.getTarget();
                Tuple3<Double, Double, Double> value = vertexValue.get(dest);
                Double payload = sendAvailability ? value.f2 : value.f1;
                sendMessageTo(dest, new Tuple2<Long, Double>(vertexKey, payload));
            }
        }
    }
}