In this post, we’ll see how to make a simple transformer for Spark ML Pipelines. The transformer we’ll design will generate a sparse binary feature vector from an array-valued field representing a set.
Preliminaries
The first thing we’ll need to do is expose Spark’s user-defined type for vectors. This will enable us to write a user-defined data frame function that returns a Spark vector. (We could also implement our own user-defined type, but reusing Spark’s, which is currently private to Spark, will save us some time. By the time you read this, the type may be part of Spark’s public API – be sure to double-check!)
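package org.apache.spark {
  // make VectorUDT available outside of Spark code; type aliases can't be
  // declared directly in a package, so we put this one in a package object
  package object hacks {
    type VectorType = org.apache.spark.mllib.linalg.VectorUDT
  }
}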
Imports
Here are the imports we’ll need for the transformer and support code. I’ll use VEC for Spark vectors to avoid confusion with Scala’s Vector type. We’ll assume that the VectorType code from above is available on your project’s classpath.
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.param._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.linalg.{Vector => VEC, Vectors}
import org.apache.spark.hacks.VectorType
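The {Vector => VEC} clause above is ordinary Scala import renaming. As a minimal illustration that runs without Spark, the sketch below applies the same idiom to java.util.List, which is otherwise easy to confuse with Scala’s List; the choice of java.util classes is only for demonstration.

```scala
// Rename Java's List and ArrayList on import so they can't be confused with
// Scala's own List, the same way the post renames Spark's Vector to VEC.
import java.util.{List => JList, ArrayList => JArrayList}

// Java's mutable list, used through the renamed types
val javaList: JList[String] = new JArrayList[String]()
javaList.add("spark")
javaList.add("ml")

// Scala's immutable List is still available under its usual name
val scalaList: List[String] = List("spark", "ml")
```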
Transformer and support code
Most of the ML pipeline classes distributed with Spark follow the convention of putting groups of related pipeline stage parameters in a trait. We’ll do this as well, declaring a trait for the three parameters that our transformer will use: the name of the input column, the name of the output column, and the maximum number of elements our sparse vector can hold. We’ll also define a convenience method to return a triple of the parameter values we care about.
trait SVParams extends Params {
  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")
  val vecSize = new IntParam(this, "vecSize", "maximum sparse vector size")

  // returns (input column, output column, vector size), falling back to defaults
  def pvals(pm: ParamMap) = (
    // this code can be cleaner in versions of Spark after 1.3
    pm.get(inputCol).getOrElse("topicSet"),
    pm.get(outputCol).getOrElse("features"),
    pm.get(vecSize).getOrElse(128)
  )
}
Note that Spark 1.4 supports calling getOrElse directly on a ParamMap instance, so you can slightly simplify the code in pvals if you don’t care about source compatibility with Spark 1.3.
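To make the defaulting behavior of pvals concrete without a Spark dependency, here is a small sketch. Key and SimpleParamMap are hypothetical stand-ins for Spark’s Param and ParamMap (they are not Spark APIs); pvals13 mirrors the Option-based lookup used above, and pvals14 shows the one-call getOrElse style that Spark 1.4 allows.

```scala
// Hypothetical typed key standing in for Spark's Param[T]; only the name is used for lookup.
final case class Key[T](name: String)

// Hypothetical stand-in for Spark's ParamMap, just enough to show how defaults resolve.
final case class SimpleParamMap(entries: Map[String, Any] = Map.empty) {
  def put[T](k: Key[T], v: T): SimpleParamMap = copy(entries + (k.name -> v))
  // 1.3-style lookup: an Option that the caller must unwrap
  def get[T](k: Key[T]): Option[T] = entries.get(k.name).map(_.asInstanceOf[T])
  // 1.4-style convenience: fall back to a default in a single call
  def getOrElse[T](k: Key[T], default: T): T = get(k).getOrElse(default)
}

val inputCol = Key[String]("inputCol")
val outputCol = Key[String]("outputCol")
val vecSize = Key[Int]("vecSize")

// Mirrors pvals in the post: Option lookup, then a default
def pvals13(pm: SimpleParamMap): (String, String, Int) = (
  pm.get(inputCol).getOrElse("topicSet"),
  pm.get(outputCol).getOrElse("features"),
  pm.get(vecSize).getOrElse(128)
)

// Same defaults, without the intermediate Option
def pvals14(pm: SimpleParamMap): (String, String, Int) = (
  pm.getOrElse(inputCol, "topicSet"),
  pm.getOrElse(outputCol, "features"),
  pm.getOrElse(vecSize, 128)
)

// Only vecSize is overridden; the column names keep their defaults
val overridden = SimpleParamMap().put(vecSize, 1024)
```

Either style resolves the same values; the 1.4 form simply avoids unwrapping an Option at every call site.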
Here’s what the actual transformer implementation looks like:
class SetVectorizer(override val uid: String)
    extends Transformer with SVParams {
  val VT = new org.apache.spark.hacks.VectorType()

  def transformSchema(schema: StructType, params: ParamMap) = {
    val outc = params.get(outputCol).getOrElse("features")
    // append a nullable vector-valued column to the input schema
    StructType(schema.fields ++ Seq(StructField(outc, VT, true)))
  }

  def transform(df: DataFrame, params: ParamMap) = {
    val (inCol, outCol, maxSize) = pvals(params)
    df.withColumn(outCol, callUDF({ a: Seq[Int] =>
      // sparse vectors expect increasing indices, so sort the set's elements
      val indices = a.toArray.sorted
      Vectors.sparse(maxSize, indices, Array.fill(indices.length)(1.0))
    }, VT, df(inCol)))
  }
}
The first thing we do in the transformer class is declare an instance of VectorType to use in data frame type declarations later in the class. The transformSchema method returns the schema that results from applying this transformer to a given data frame: it includes all of the fields from the original frame plus a Vector-valued field named by the value of the outputCol parameter. Finally, the transform method creates a new data frame with an additional column (again, named by the value of the outputCol parameter) whose values are the result of applying a user-defined function to each row, taking its argument from the input column. The function itself simply creates a sparse binary vector from an array-backed set. Since vector indices are zero-based, the set Array(1,2,4,8) becomes a sparse vector with the elements at indices 1, 2, 4, and 8 (the second, third, fifth, and ninth elements) set to 1 and everything else set to 0.
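The encoding can be sketched in plain Scala. SparseBinary and encodeSet below are hypothetical names, not Spark classes: SparseBinary just holds the (size, indices, values) arguments that the UDF passes to Vectors.sparse, and toDense expands them so the zero-based layout is visible. The sketch sorts and deduplicates the indices on the assumption that callers might pass unordered input, since sparse vectors expect indices in increasing order.

```scala
// Hypothetical container mirroring the (size, indices, values) arguments that the
// UDF passes to Vectors.sparse; this is not Spark's SparseVector class.
final case class SparseBinary(size: Int, indices: Array[Int], values: Array[Double]) {
  // Expand to a dense array so the zero-based layout is easy to inspect.
  def toDense: Array[Double] = {
    val dense = Array.fill(size)(0.0)
    indices.zip(values).foreach { case (i, x) => dense(i) = x }
    dense
  }
}

// Encode an array-backed set of indices as a binary sparse vector of length maxSize.
// Sorting and deduplicating assumes callers may pass unordered input.
def encodeSet(set: Seq[Int], maxSize: Int): SparseBinary = {
  val indices = set.distinct.sorted.toArray
  SparseBinary(maxSize, indices, Array.fill(indices.length)(1.0))
}

// Index 8 is the ninth slot, because indices start at 0.
val encoded = encodeSet(Seq(8, 1, 4, 2), 10)
```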
The code above is a reasonable starting point for your own transformers, but you’ll want to add error checking to code you use in production. At a minimum, you’d need to validate the schema of the input data frame (to ensure that expected columns exist and are of the correct type), verify that the output column name doesn’t already exist in the data frame, and make sure that every element of each input array is a valid index, that is, between 0 and vecSize - 1 (for a set, this also rules out having more than vecSize elements). I hope this code is helpful as you develop your own pipeline stages!
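As one way to structure these checks, here is a Spark-free sketch. Field is a hypothetical stand-in for Spark’s StructField, with the column type written as a string such as "array<int>" (an assumption of this sketch); the functions return descriptions of problems rather than throwing, which keeps them easy to test. In a real transformer, the schema checks would naturally fit in transformSchema, while the per-row index check would have to run inside the UDF or a preceding filter, because rows are only visible when the job executes.

```scala
// Hypothetical stand-in for Spark's StructField: a column name and a type name.
final case class Field(name: String, dataType: String)

// Schema-level checks; returns human-readable problems (an empty list means OK).
def validateSchema(schema: Seq[Field], inCol: String, outCol: String): List[String] = {
  val inputProblems = schema.find(_.name == inCol) match {
    case None => List(s"missing input column '$inCol'")
    case Some(f) if f.dataType != "array<int>" =>
      List(s"input column '$inCol' has type ${f.dataType}, expected array<int>")
    case Some(_) => Nil
  }
  val outputProblems =
    if (schema.exists(_.name == outCol)) List(s"output column '$outCol' already exists")
    else Nil
  inputProblems ++ outputProblems
}

// Row-level check: every element must index into a vector with vecSize slots.
def validateRow(set: Seq[Int], vecSize: Int): Option[String] =
  set.find(i => i < 0 || i >= vecSize).map(i => s"index $i is outside [0, $vecSize)")

// Example: a schema with an array-valued input column and an unrelated id column
val schema = Seq(Field("topicSet", "array<int>"), Field("id", "bigint"))
```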