By Michael S. Malak and Robin East

Semi-supervised learning combines the best of both worlds of supervised learning and unsupervised learning. In this article, excerpted from Spark GraphX in Action, we talk about semi-supervised learning.

Semi-supervised learning combines the best of both worlds of supervised learning and unsupervised learning. Although supervised learning has the advantage of predicting human-understandable labels (because it was trained with labeled data), the disadvantage is the time required for a human to label all that training data. That’s very expensive. Because unsupervised learning is trained with unlabeled data, vastly larger training data sets are easier to come by.

The general approach behind semi-supervised learning is to first perform unsupervised learning on the unlabeled data. This will provide some structure that can be applied to the labeled data. Then this enhanced labeled data can be trained using supervised learning to generate more powerful models.

In this article, we handle the situation where we have a bunch of data points in a multi-dimensional space, such as a 2-D plane, a 3-D cube, or higher dimension. The axes of such a space could represent any variable: temperature, test scores, population, etc. The idea is that we’re trying to attach class labels to points in this space under the assumption that similar points will be clustered together.

For example, in the case of a cable TV and Internet provider, a 2-D plane could be constructed with axes for hours of television watched vs. gigabytes of data transferred. We could then distinguish some categories, for example, heavy TV, heavy Internet, and heavy users of both.

By generating a graph to fit these data points, boundaries between these classes of users can be determined. As it turns out, identifying these clusters of similar users will boost the power of our prediction algorithm when we apply it to new data points.

To implement this idea, we'll first implement a K-Nearest Neighbors graph construction algorithm (not to be confused with the K-Nearest Neighbors algorithm for computing a prediction, which is different and not covered in this book), which will serve as our unsupervised learning piece, and apply it to a data set, the vast majority of which is unlabeled. Then we'll implement a simple label propagation algorithm that propagates the labels to surrounding unlabeled vertices. Finally, we'll implement a simple knnPredict() function that, given a new data point, predicts which class (label) it belongs to.


Many machine learning algorithms are prefixed with the letter k. Generally this refers to a parameter in the model that is conventionally named k and needs to be chosen by the user. The actual meaning of k depends on the specific algorithm, although there are classes of algorithms that all use k in a similar way. For example, in the clustering algorithms k-means and k-medians, k is the number of clusters that we are asking the algorithm to generate. In k-nearest neighbors algorithms, we infer something about a point by looking at a number of the most similar points; in this case we have to choose k, the number of most similar points to consider.
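To make the role of k concrete, here is a minimal plain-Scala sketch (the data and names are invented for this illustration) of picking the k most similar points to a query on a one-dimensional number line:

```scala
// Hypothetical example: find the k = 3 values most similar to a query
// point on a 1-D number line. "Similar" here just means smallest
// absolute distance.
val points = List(1.0, 4.0, 5.5, 9.0, 10.0, 12.0)
val query = 5.0
val k = 3

// Sort all points by distance to the query and keep the k most similar.
val kNearest = points.sortBy(p => math.abs(p - query)).take(k)
// kNearest: List(5.5, 4.0, 1.0)
```

Changing k changes how much of the neighborhood we consult: k = 1 looks only at the single closest point, while a large k smooths over local structure.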


Figure 1 Starting condition: a bunch of points in two-dimensional space, almost all of them unlabeled, with the exception of two labeled points.

Figure 1 shows the starting condition, and figure 2 shows what it looks like after both the K-Nearest Neighbors graph construction and the semi-supervised learning label propagation algorithms have run. These horseshoe-shaped clusters of data are the classic counterexample for K-Means, another type of clustering algorithm (not related at all to K-Nearest Neighbors). K-Means is focused on finding centroids of clusters and gets confused by long, stringy chains of points. But because this approach of using K-Nearest Neighbors for graph construction can follow such chains, it won't get confused by this type of data.

K-Nearest Neighbors Graph Construction

Spark does not (as of version 1.6) contain an implementation for the K-Nearest Neighbors algorithm. That is the subject of Jira ticket SPARK-2335.

Conceptually, finding the k nearest neighbors is trivially simple. For every point, find its k nearest neighbors out of all the other points and extend edges to those points. And in fact, this naïve brute-force approach is shown in listing 1.


Figure 2 After both the K-Nearest Neighbors graph construction algorithm and the semi-supervised learning label propagation algorithm have run.

Listing 1 Brute force K Nearest Neighbors

import org.apache.spark.graphx._

case class knnVertex(classNum:Option[Int],
                     pos:Array[Double]) extends Serializable {
  def dist(that:knnVertex) = math.sqrt(
             => (x._1-x._2)*(x._1-x._2)).reduce(_ + _))
}

def knnGraph(a:Seq[knnVertex], k:Int) = {
  val a2 = => (x._2.toLong, x._1)).toArray
  val v = sc.makeRDD(a2)
  val e = => (v1._1, => (v2._1, v1._2.dist(v2._2)))
                                .sortWith((e,f) => e._2 < f._2)
                                .slice(1,k+1)
                                .map(_._1)))
           .flatMap(x => =>
             Edge(x._1, vid2,
                  1 / (1 + a2(vid2.toInt)._2.dist(a2(x._1.toInt)._2)))))
  Graph(v,e)
}

The problem is performance. For each of the n points, n distances have to be computed, and then those n distances have to be sorted, at a cost of n log n. So that's n² log n overall. The various K-Nearest Neighbors graph construction algorithms out there attempt to solve the problem more efficiently; because computing the exact answer is prohibitively expensive at scale, they all settle for approximate solutions. We'll look at such an approximate approach, and one suited to Spark's distributed processing, later in this subsection.
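To see how quickly that cost grows, here is a rough, illustrative cost model (not an exact operation count) for the brute-force approach:

```scala
// Rough cost model for brute-force k-NN graph construction:
// for each of the n points, compute n distances and sort them
// (n log n), giving about n * n * log2(n) units of work.
def bruteForceCost(n: Double) = n * n * (math.log(n) / math.log(2))

// Doubling n roughly quadruples the cost (slightly more, because of
// the log factor):
val ratio = bruteForceCost(2000) / bruteForceCost(1000)
// ratio is about 4.4
```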

But first, let's look at listing 1. If you'd like to try it out, listing 2 will generate the data shown in figure 1, and listing 3 is a special export to Gephi .gexf file format tailored to our knnVertex that outputs color and position tags. Listing 4 executes the algorithm and the export to .gexf. Here we choose k=4 for K-Nearest Neighbors; 3 or 4 are typical values for k.

Listing 2 Generate example data

import scala.util.Random

val n = 10
val a = (1 to n*2).map(i => {
  val x = Random.nextDouble
  if (i <= n)
    knnVertex(if (i % n == 0) Some(0) else None, Array(x*50,
      20 + (math.sin(x*math.Pi) + Random.nextDouble / 2) * 25))
  else
    knnVertex(if (i % n == 0) Some(1) else None, Array(x*50 + 25,
      30 - (math.sin(x*math.Pi) + Random.nextDouble / 2) * 25))
})

Listing 3 Custom export (with layout) to Gephi .gexf for knnVertex-based graphs

import java.awt.Color

def toGexfWithViz(g:Graph[knnVertex,Double], scale:Double) = {
  val colors = Array(,, Color.yellow,,
                     Color.magenta,, Color.darkGray)
  "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
  "<gexf xmlns=\"\" " +
        "xmlns:viz=\"\" " +
        "version=\"1.2\">\n" +
  "  <graph mode=\"static\" defaultedgetype=\"directed\">\n" +
  "    <nodes>\n" +
   =>
    "      <node id=\"" + v._1 + "\" label=\"" + v._1 + "\">\n" +
    "        <viz:position x=\"" + v._2.pos(0) * scale +
              "\" y=\"" + v._2.pos(1) * scale + "\" />\n" +
    (if (v._2.classNum.isDefined)
       "        <viz:color r=\"" + colors(v._2.classNum.get).getRed +
                 "\" g=\"" + colors(v._2.classNum.get).getGreen +
                 "\" b=\"" + colors(v._2.classNum.get).getBlue + "\" />\n"
     else "") +
    "      </node>\n").collect.mkString +
  "    </nodes>\n" +
  "    <edges>\n" +
   => "      <edge source=\"" + e.srcId +
                   "\" target=\"" + e.dstId + "\" label=\"" + e.attr +
                   "\" />\n").collect.mkString +
  "    </edges>\n" +
  "  </graph>\n" +
  "</gexf>"
}

Listing 4 Execute K Nearest Neighbors on the example data and export to .gexf

val g = knnGraph(a, 4)
val pw = new"knn.gexf")
pw.write(toGexfWithViz(g, 10))
pw.close

The core of listing 1 is the computation of e, an RDD of Edges that we pass into Graph() to create the return value graph. And within this computation, we can see the n² nature of the work. There is an outer map() (performed on v, the RDD of vertices) and an inner map() (performed on a2, the Array version of v). For each vertex we compute and sort all the distances, and pick off the k with the shortest distances (we ignore index 0, because that's just the same vertex as itself, at distance zero). When we construct the actual Edge() at the end, we use the edge attribute to store the inverse of the distance. This will be used in the semi-supervised learning label propagation but isn't needed for K-Nearest Neighbors itself.

The use of flatMap() in listing 1 is a non-trivial one: it effectively does both a map() (to transform each collection of distant vertices into a collection of Edges) and a flatten() (to make a single collection of Edges out of the collection of collections of Edges).
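This map-then-flatten equivalence is easy to see on plain Scala collections (the same identity holds for Spark RDDs); the data here is invented for illustration:

```scala
// Each inner list stands in for one vertex's neighbor ids.
val neighborLists = List(List(1L, 2L), List(3L), List(4L, 5L))

// flatMap transforms each inner collection and flattens in one step...
val viaFlatMap = neighborLists.flatMap(ids => => id * 10))

// ...which is exactly map() followed by flatten().
val viaMapFlatten = => => id * 10)).flatten

// Both yield the single flat collection List(10, 20, 30, 40, 50)
```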

Toward a distributed K Nearest Neighbors algorithm

Of the various approximate K-Nearest Neighbors graph construction algorithms out there, most are geared toward conventional serial processing rather than distributed parallel processing. A notable exception that does use distributed processing is from the 2012 Microsoft Research paper "Scalable k-NN graph construction for visual descriptors" by Wang et al.

That paper includes a lot of optimizations for distributed computing, but here we’ll just take and implement one of their many ideas and ignore the rest. It won’t in the slightest do their paper justice, but it’ll put us on a path toward a more practical K Nearest Neighbors implementation for Spark.

The first key insight from the Wang paper, and the only one we adapt, is to break the space up into grids and perform the brute-force K-Nearest Neighbors graph construction algorithm on each cell in the grid. In figure 3, the space is broken up into 3×3 grids in two different ways (in the first of the two divisions, the last cells have zero width or height). If we then say m=3, the complexity is cm²(n/m²)(n/m²)log(n/m²) = c(n/m)²log(n/m²), where c is the number of different grids we use.
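The cell-assignment idea can be sketched in plain Scala (the helper name cellKey and the sample coordinates are ours, for illustration); it mirrors the key formula used later in the distributed listing and shows why the second, half-cell-shifted grid matters:

```scala
// Map a point to an integer grid-cell key. offset = 0.5 shifts the
// grid by half a cell, so neighbors split by a boundary of the first
// grid can end up in the same cell of the second grid.
val n = 3                                    // an n x n grid
def cellKey(x:Double, y:Double,
            minX:Double, xRange:Double,
            minY:Double, yRange:Double,
            offset:Double) =
  math.floor((x - minX) / xRange * (n-1) + offset) * n +
  math.floor((y - minY) / yRange * (n-1) + offset)

// Two nearby points straddling a cell boundary of the unshifted grid...
val k1 = cellKey(24.9, 10, 0, 50, 0, 50, 0.0)
val k2 = cellKey(25.1, 10, 0, 50, 0, 50, 0.0)   // different cell than k1

// ...land in the same cell once the grid is shifted by half a cell.
val k3 = cellKey(24.9, 10, 0, 50, 0, 50, 0.5)
val k4 = cellKey(25.1, 10, 0, 50, 0, 50, 0.5)   // same cell as k3
```

Running brute force within each cell, for each of the two grids, and taking the union of the resulting edge sets is what keeps boundary-crossing edges from being missed.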


Figure 3 Distributed K-Nearest Neighbor graph construction. Divide the space into grids, and perform brute-force K-Nearest Neighbor graph construction within each grid cell. To avoid missing edges that would cross a cell boundary, vary the grid and run again, and take the union of the two edge sets.

This is the simple approach we will take. Again, the full algorithm described in the Wang paper is much more sophisticated, such as its use of Principal Component Analysis (PCA) to determine the grid orientation (breaking the space up into parallelograms instead of squares), using many random grids, and coalescing directed edges into undirected edges.

For implementing in Spark, we want to map each grid cell onto a separate executor (task). This can be done by paying attention to Spark RDD partitioning.


RDD Partitioning (and mapPartitions())

At a fundamental level, RDDs are distributed data sets, and how Spark decides to distribute the data amongst the nodes in the cluster depends on the RDD’s partitioner. The default partitioner is the HashPartitioner, which hashes the key in Tuple2[K,V] key/value pairs, thus sending RDD data elements with equal keys to the same node. This, of course, assumes the RDD is composed of key/value pairs in the first place (like those that PairRDDFunctions operate on). But if it’s a plain old RDD with no keys, then Spark just makes up random keys before running it through the HashPartitioner.
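Conceptually, hash partitioning routes each key as in this simplified sketch (an illustration of the idea only, not Spark's actual internal code):

```scala
// Simplified model of hash partitioning: a key goes to partition
// nonNegativeMod(key.hashCode, numPartitions), so equal keys always
// land on the same partition (and hence the same node).
val numPartitions = 4
def partitionOf(key: Any) = {
  val raw = key.hashCode % numPartitions
  if (raw < 0) raw + numPartitions else raw   // % can be negative in Scala
}

// Equal keys are routed identically:
val same = partitionOf("user42") == partitionOf("user42")
```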

As described in section 9.3, GraphX adds another layer of abstraction to partitioning. But under the covers, GraphX is just controlling partitioning via HashPartitioner and setting hidden keys to ensure HashPartitioner puts RDD elements where it wants them.

But when making a copy of the edges RDD or vertices RDD – recall that RDDs are immutable and any operations done on them make copies – we may opt to apply our own partitioning for performance or algorithmic purposes. A convenient side effect of groupByKey() is that data is shuffled and repartitioned by key. This can sometimes obviate having to create a custom partitioner, which involves subclassing Partitioner and overriding member functions.

Partitioning is something that happens behind the scenes, and we normally do not need to worry about it. If, however, we want to specify exactly where data goes, either for performance or algorithmic purposes, then we need to pay attention. An important means of making good use of partitions is the mapPartitions() function.

mapPartitions() lets you deal with all the data in a partition in the form of a Scala collection. This lets you do any expensive set-up and tear-down, such as creating a database cursor or instantiating and initializing a parser object, once per partition before the executor goes to work on its portion of the RDD. If you were to try to do this using RDD's plain old map(), that expensive operation would be performed once per data element instead of once per partition.
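Here is a plain-Scala analogue of that difference (the partitions are simulated with grouped(), and the names are ours for illustration): with per-partition processing, the expensive setup runs once per chunk rather than once per element.

```scala
// Count how many times the "expensive" setup runs. In real code this
// might open a database cursor or build a parser; here it just
// returns a doubling function.
var setupCount = 0
def expensiveSetup() = { setupCount += 1; (x: Int) => x * 2 }

val data = (1 to 12).toList
val result = data.grouped(4)            // simulate 3 partitions of 4
  .flatMap { part =>
    val f = expensiveSetup()            // runs once per "partition"        // applied to every element in the chunk
  }.toList

// setupCount is 3 (once per partition), not 12 (once per element)
```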

In listing 5, we use groupByKey() to shuffle the data to partitions as shown in figure 3, and then we use mapPartitions() to do the brute-force K-Nearest Neighbors edge generation within each cell. mapPartitions() allows us to capture up front (into the variable af) the full subset of vertices inside that grid cell (say there are d vertices in the cell), and then compute the d² distances and complete the nearest-neighbor edge generation.

The result of executing this approximate K-Nearest Neighbors graph generation algorithm on the example data, followed by executing the semi-supervised learning label propagation algorithm described in the next subsection, is shown in figure 4.

Listing 5 Distributed, approximate K Nearest Neighbors graph generation

def knnGraphApprox(a:Seq[knnVertex], k:Int) = {
  val a2 = => (x._2.toLong, x._1)).toArray
  val v = sc.makeRDD(a2)
  val n = 3
  val minMax =
     => (x._2.pos(0), x._2.pos(0), x._2.pos(1), x._2.pos(1)))
     .reduce((a,b) => (math.min(a._1,b._1), math.max(a._2,b._2),
                       math.min(a._3,b._3), math.max(a._4,b._4)))
  val xRange = minMax._2 - minMax._1
  val yRange = minMax._4 - minMax._3

  def calcEdges(offset: Double) =
     => (math.floor((x._2.pos(0) - minMax._1)
                           / xRange * (n-1) + offset) * n
                  + math.floor((x._2.pos(1) - minMax._3)
                               / yRange * (n-1) + offset),
                  x))
     .groupByKey(n*n)
     .mapPartitions(ap => {
       val af = ap.flatMap(_._2).toList
       => (v1._1, => (v2._1, v1._2.dist(v2._2)))
                              .sortWith((e,f) => e._2 < f._2)
                              .slice(1,k+1)
                              .map(_._1)))
         .flatMap(x => => Edge(x._1, vid2,
            1 / (1+a2(vid2.toInt)._2.dist(a2(x._1.toInt)._2)))))
         .iterator
     })

  val e = calcEdges(0.0).union(calcEdges(0.5))
                        .distinct
                        .map(x => (x.srcId,x))
                        .groupByKey
                        .map(x => x._2.toArray
                                   .sortWith((e,f) => e.attr > f.attr)
                                   .take(k))
                        .flatMap(x => x)

  Graph(v,e)
}


Spark tip

The RDD function union(), unlike SQL UNION, does not eliminate duplicates. You have to call distinct() right afterward if you want unique values in your resultant RDD.

Note that in order to get groupByKey() to partition and shuffle the way we expect, we had to use its optional parameter to specify the number of partitions. If we didn't, groupByKey() might combine some partitions into one if they are small. Because the partitioning affects our algorithm, in this case we insist on the larger number of partitions. We specify the maximum it could be (n*n); if there happen to be fewer keys (because some grid cells are empty), groupByKey() will simply use as many partitions as there are keys.

Also note that, because of the above, the parameter passed into the function we supply to mapPartitions() is technically not for just a single key: it's a collection that can contain multiple keys. Because we assume we're getting a single key per partition, we start off with a flatMap() on that parameter to eliminate the extra level of nesting.


Figure 4 Result of executing both the approximate distributed K Nearest Neighbors algorithm and the semi-supervised learning label propagation algorithm from the next section. There are only about two-thirds as many edges in this one.

When we calculate e, the set of edges, at the end, we union() the two sets of edges from the two possible grids shown in figure 3. Because that may result in more than k edges for any given vertex, we trim the list down with the groupByKey(), map(), flatMap() sequence of function calls. The comparator in sortWith() is a greater-than rather than the usual less-than because the edge attributes are the reciprocals of the distances rather than the distances themselves.
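A tiny illustration (with made-up distances) of why the comparator flips: because each edge attribute stores 1/(1+distance), "keep the k nearest" becomes "keep the k largest attributes".

```scala
// Three hypothetical edge distances and their stored attributes.
val distances = List(3.0, 1.0, 2.0)
val attrs = => 1 / (1 + d))   // reciprocals: larger = nearer

// Sorting attributes with greater-than puts the nearest edge first.
val nearestFirst = attrs.sortWith((e,f) => e > f)
// nearestFirst: List(0.5, 0.333..., 0.25) -- i.e. distances 1.0, 2.0, 3.0
```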

Semi-Supervised Learning Label Propagation

What we have done so far is extract some structure from all the points in our data set without worrying about what label we will apply to them. Figure 5 shows the structure we have built up. Our two labeled points are colored (red and blue), but most of our points have no labels associated with them and remain gray. We resolve this now by implementing a label propagation algorithm that assigns a label to all those gray vertices. We then show how this fully labeled model can be used to predict the label for a new unlabeled data point.


Figure 5 The graph that results from applying our distributed, approximate unsupervised learning algorithm. Now there is structure but unlabeled vertices remain unlabeled.

Now we’re ready to implement the label propagation. Spark’s built-in label propagation algorithm takes a dataset of already labeled vertices and attempts to identify and label communities through a label consensus process.

By contrast, what we present in this article is a means of propagating labels from a few known labeled vertices to a much larger selection of unlabeled vertices by using the graph structure built by our unsupervised learning. It also takes into account edge distances, weighting nearby vertices more heavily. The result is an algorithm that almost always converges.

The algorithm can be described as:

  1. For each edge emanating from a labeled vertex, send that vertex’s label together with the edge weight (that is, the reciprocal of the edge length) to both the source and destination of the edge.
  2. For each vertex, add up the scores on a by-class (by-label) basis. If the vertex is not one of the vertices with a pre-known, fixed label, then assign the winning class (label) to the vertex.
  3. If no vertices changed labels, or if maxIterations is reached, then terminate.
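One voting round from the steps above can be sketched in plain Scala (the message values here are made up for illustration): an unlabeled vertex receives (label, weight) messages from its labeled neighbors, sums the weights per label, and takes the heaviest label. Because the weights are reciprocals of distances, nearer neighbors count more.

```scala
// Messages arriving at one unlabeled vertex: (class label, edge weight).
val messages = List((0, 0.50), (1, 0.25), (0, 0.30), (1, 0.40))

// Sum the weights on a by-class basis.
val scores = messages.groupBy(_._1)
  .map { case (label, msgs) => (label, }
// class 0 scores 0.80; class 1 scores 0.65

// The winning class is the one with the largest total weight.
val winner = scores.maxBy(_._2)._1
```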

We use aggregateMessages() (together with joinVertices()) rather than Pregel() because Pregel()'s terminating condition is that no more messages are sent. Here, we always send a labeled vertex's label back to itself to ensure that permanently labeled vertices retain their labels, so we can't use Pregel() in this code.

Note that the gist of this algorithm is that it treats the graph as an undirected graph. The actual implementation treats source and destination slightly differently in its attempt to ensure that permanently labeled vertices never switch their label, but conceptually labels can travel in either direction along the edge.

Figure 6 illustrates iteration by iteration the application of this algorithm to the perfect K Nearest Neighborhood graph from figures 1 and 2.


Scala Tip

The operator -> is shorthand for establishing a key/value pair (a Tuple2) in a Scala HashMap. For those familiar with PHP, this is similar to PHP's array initialization using =>.
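A quick illustration that -> just builds a pair, so the two initializations below produce identical maps:

```scala
import scala.collection.mutable.HashMap

// key -> value is syntactic sugar for the tuple (key, value).
val m1 = HashMap(1 -> 0.5, 2 -> 0.25)
val m2 = HashMap((1, 0.5), (2, 0.25))
// m1 == m2
```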

Listing 6 Semi-Supervised Learning Label Propagation

import scala.collection.mutable.HashMap

def semiSupervisedLabelPropagation(g:Graph[knnVertex,Double],
                                   maxIterations:Int = 0) = {
  val maxIter = if (maxIterations == 0) g.vertices.count / 2
                else maxIterations

  var g2 = g.mapVertices((vid,vd) => (vd.classNum.isDefined, vd))
  var isChanged = true
  var i = 0

  do {
    val newV =
      g2.aggregateMessages[Tuple2[Option[Int],HashMap[Int,Double]]](
        ctx => {
          ctx.sendToSrc((ctx.srcAttr._2.classNum,
                         if (ctx.dstAttr._2.classNum.isDefined)
                           HashMap(ctx.dstAttr._2.classNum.get->ctx.attr)
                         else
                           HashMap[Int,Double]()))
          if (ctx.srcAttr._2.classNum.isDefined)
            ctx.sendToDst((None,
                           HashMap(ctx.srcAttr._2.classNum.get->ctx.attr)))
        },
        (a1, a2) => {
          if (a1._1.isDefined)
            (a1._1, HashMap[Int,Double]())
          else if (a2._1.isDefined)
            (a2._1, HashMap[Int,Double]())
          else
            (None, a1._2 ++{
              case (k,v) => k -> (v + a1._2.getOrElse(k,0.0)) })
        })

    val newVClassVoted = => (x._1,
      if (x._2._1.isDefined)
        x._2._1
      else if (x._2._2.size > 0)
        Some(x._2._2.toArray.sortWith((a,b) => a._2 > b._2)(0)._1)
      else None
    ))

    isChanged = g2.vertices.join(newVClassVoted)
                           .map(x => x._2._1._2.classNum != x._2._2)
                           .reduce(_ || _)

    g2 = g2.joinVertices(newVClassVoted)((vid, vd1, u) =>
      (vd1._1, knnVertex(u, vd1._2.pos)))

    i += 1
  } while (i < maxIter && isChanged)

  g2.mapVertices((vid,vd) => vd._2)
}



Figure 6 Iterations of semi-supervised learning label propagation applied to the perfect K Nearest Neighbors example.


Now that the graph is trained for semi-supervised learning, we can use it to "predict" labels. That is, given a point with (x,y) coordinates, to which class (label) does it belong? Listing 7 is code for a dead-simple prediction function. It simply finds the closest labeled vertex (regardless of whether it was originally labeled or got its label as a result of the propagation) and returns that vertex's label. Technically, this implements k-nearest neighbors prediction (not to be confused with k-nearest neighbors graph construction) with k=1.

Listing 7 Prediction function to use the semi-supervised learned graph

def knnPredict[E](g:Graph[knnVertex,E],pos:Array[Double]) =
  g.vertices
   .filter(_._2.classNum.isDefined)
   .map(x => (x._2.classNum.get, x._2.dist(knnVertex(None,pos))))
   .min()(new Ordering[Tuple2[Int,Double]] {
      override def compare(a:Tuple2[Int,Double],
                           b:Tuple2[Int,Double]): Int =
        a._2.compare(b._2) })
   ._1

Listing 8 Execute semi-supervised learning label propagation and then use it to predict a class (label) for a particular (x,y) coordinate of (30.0,30.0)

val gs = semiSupervisedLabelPropagation(g)
 knnPredict(gs, Array(30.0,30.0))
 res5: Int = 0



Neither GraphX nor MLlib has semi-supervised learning built in, but semi-supervised learning can be achieved via a combination of K-Nearest Neighbors graph construction and an intuitive label propagation algorithm.