This brief post is based on material that Erik and I didn’t have time to cover in our Spark+AI Summit talk; it will show you how to use Scala’s implicit parameter mechanism to work around an aspect of the RDD API that can make it difficult to write generic functions. This post will be especially useful for experienced Spark users who are relatively new to Scala.
If you’ve written reusable code that uses Spark’s RDD API, you might have run into headaches related to variance. RDD is invariant in its type parameter, meaning that RDD[T] and RDD[U] are unrelated types if T and U are different types, even if there is a subtyping relation between T and U.
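Scala’s standard library shows the same distinction, so you can observe it without a Spark cluster: Array[T] is invariant, like RDD[T], while List[+A] is declared covariant. In this minimal sketch, Animal, Dog, and VarianceDemo are hypothetical names chosen only for illustration.

```scala
// Hypothetical types used only to illustrate variance.
trait Animal { def name: String }
final case class Dog(name: String) extends Animal

object VarianceDemo {
  val dogs: List[Dog] = List(Dog("Rex"), Dog("Fido"))

  // List is covariant (List[+A]), so a List[Dog] is also a List[Animal].
  val animals: List[Animal] = dogs

  val dogArray: Array[Dog] = Array(Dog("Rex"), Dog("Fido"))

  // Array is invariant, so this line would NOT compile:
  // val animalArray: Array[Animal] = dogArray
  // (type mismatch: found Array[Dog], required Array[Animal])

  // Building a new array by upcasting each element does compile.
  val animalArray: Array[Animal] = dogArray.map(d => (d: Animal))
}
```

The map-based copy works because each element is upcast individually; the invariant container itself is never treated as a supertype.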
Let’s say you had a Scala trait and some concrete class extending that trait, like these:
trait HasUserId { val userid: Int }
case class Transaction(override val userid: Int,
                       amount: Int,
                       timestamp: Double)
  extends HasUserId {}
You might then want to write a function operating on an RDD of any type that is a subtype of your HasUserId trait, like this:
import org.apache.spark.rdd.RDD

def badKeyByUserId(r: RDD[HasUserId]) = r.map(x => (x.userid, x))
Unfortunately, this code isn’t very useful, because RDDs are invariant. Let’s apply it to a concrete RDD of some type that is a subtype of HasUserId:
val xacts = spark.sparkContext.parallelize(Array(
  Transaction(1, 1, 1.0),
  Transaction(2, 2, 1.0)
))
badKeyByUserId(xacts)
This will fail to compile due to a type mismatch: we’ve supplied an org.apache.spark.rdd.RDD[Transaction], but the function requires an org.apache.spark.rdd.RDD[HasUserId]. Since there is no subtyping relation between these two, we cannot supply the former in place of the latter. We could explicitly cast our RDD or its elements to get our code to compile and run:
/* cast the collection all at once */
badKeyByUserId(xacts.asInstanceOf[RDD[HasUserId]])
/* or cast each element */
badKeyByUserId(xacts.map(x => x.asInstanceOf[HasUserId]))
Explicit casts are clunky, though, and they also cost us precision: once we’ve cast up to HasUserId, the result is an RDD[(Int, HasUserId)], and we have no safe way to get back to an RDD[(Int, Transaction)].
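To make that loss of precision concrete, here is a sketch that uses Scala’s Array as a stand-in for RDD (an assumption so it runs without Spark; Array is also invariant). CastDemo is a hypothetical name, and its badKeyByUserId simply mirrors the RDD version above.

```scala
trait HasUserId { val userid: Int }
case class Transaction(override val userid: Int, amount: Int, timestamp: Double)
  extends HasUserId

object CastDemo {
  // Array-based analogue of badKeyByUserId: the element type is fixed.
  def badKeyByUserId(a: Array[HasUserId]): Array[(Int, HasUserId)] =
    a.map(x => (x.userid, x))

  val xacts: Array[Transaction] =
    Array(Transaction(1, 1, 1.0), Transaction(2, 2, 1.0))

  // Upcasting each element compiles...
  val keyed: Array[(Int, HasUserId)] =
    badKeyByUserId(xacts.map(x => (x: HasUserId)))

  // ...but the static type no longer knows about `amount`. Getting it back
  // requires a downcast the compiler cannot check; it would throw a
  // ClassCastException if some element were not actually a Transaction.
  val amounts: Array[Int] =
    keyed.map { case (_, h) => h.asInstanceOf[Transaction].amount }
}
```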
A better approach is to use Scala’s generic types in conjunction with implicit parameters to write a generic function that only accepts RDDs of some concrete type that is a subtype of HasUserId, like this:
import scala.reflect.ClassTag

def keyByUserId[T: ClassTag](r: RDD[T])(implicit bid: T => HasUserId) =
  r.map(x => (bid(x).userid, x))
Let’s walk through what’s happening here. When we invoke keyByUserId with an RDD of some type T, the Scala compiler will first make sure there is a function in scope mapping from T to HasUserId (see footnote 1). Put another way, the implicit formal parameter imposes a constraint on T: if there is a function that supplies evidence that T satisfies the constraint, the code will compile. For any subtype of HasUserId, Scala’s Predef supplies such a function automatically. We’ll then use that function to get a HasUserId-typed reference for each element of the collection so we can safely access the userid field. Not only will we be able to apply keyByUserId to an RDD of Transaction objects, but it will also return a result with a specific type: RDD[(Int, Transaction)].
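You can try the same pattern without a cluster by again treating Array as a stand-in for RDD: both are invariant, and building either one with map needs a ClassTag for the result element type. The sketch below is an Array-based analogue of keyByUserId, not Spark code, and KeyByDemo is a hypothetical name. The explicit type on keyed only compiles because the element type Transaction is preserved.

```scala
import scala.reflect.ClassTag

trait HasUserId { val userid: Int }
case class Transaction(override val userid: Int, amount: Int, timestamp: Double)
  extends HasUserId

object KeyByDemo {
  // Array-based analogue of keyByUserId: T may be any type for which the
  // compiler finds evidence (a function T => HasUserId) in implicit scope.
  def keyByUserId[T: ClassTag](a: Array[T])(implicit bid: T => HasUserId): Array[(Int, T)] =
    a.map(x => (bid(x).userid, x))

  val xacts: Array[Transaction] =
    Array(Transaction(1, 1, 1.0), Transaction(2, 2, 1.0))

  // Transaction is a subtype of HasUserId, so Predef supplies the evidence,
  // and the result keeps the precise element type.
  val keyed: Array[(Int, Transaction)] = keyByUserId(xacts)

  // Transaction-specific fields stay accessible without any downcast.
  val totalAmount: Int = keyed.map { case (_, t) => t.amount }.sum
}
```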
It’s worth noting that we could also define a conversion from instances of some type unrelated to HasUserId to instances of HasUserId, meaning we aren’t restricted by the subtyping relationship. You can see a similar approach in action in my explanation of implementing database-style type translation in Scala’s type system.
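Here is a minimal sketch of such a conversion, again with Array standing in for RDD. LegacyEvent, its uid field, legacyHasUserId, and ConversionDemo are all hypothetical. LegacyEvent does not extend HasUserId, yet keyByUserId accepts it because a function of type LegacyEvent => HasUserId is in implicit scope (here, in LegacyEvent’s companion object).

```scala
import scala.reflect.ClassTag

trait HasUserId { val userid: Int }

// Hypothetical type that carries a user id but does NOT extend HasUserId.
final case class LegacyEvent(uid: Int, payload: String)

object LegacyEvent {
  // Evidence that a LegacyEvent can be viewed as a HasUserId. Defining it in
  // the companion object puts it in implicit scope wherever it is needed.
  implicit val legacyHasUserId: LegacyEvent => HasUserId =
    e => new HasUserId { val userid: Int = e.uid }
}

object ConversionDemo {
  def keyByUserId[T: ClassTag](a: Array[T])(implicit bid: T => HasUserId): Array[(Int, T)] =
    a.map(x => (bid(x).userid, x))

  val events: Array[LegacyEvent] =
    Array(LegacyEvent(7, "login"), LegacyEvent(9, "logout"))

  // Compiles because LegacyEvent.legacyHasUserId supplies the evidence.
  val keyed: Array[(Int, LegacyEvent)] = keyByUserId(events)
}
```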
It should be clear that using generics in this way can capture most of what we’d like to capture with a covariant collection (that is, a collection C such that C[T] <: C[U] whenever T <: U). However, the general technique is more powerful than simply simulating covariance: what we’re doing here is using Scala’s implicit resolution to implement typeclasses, which give us typesafe ad hoc polymorphism. To see an example of how this affords us additional flexibility, let’s look at a generic method operating on RDDs of numeric values:
def multiply[T: ClassTag](r: RDD[T], v: T)(implicit ev: Numeric[T]) =
  r.map(x => ev.times(x, v))
multiply(spark.sparkContext.parallelize(Array(1, 2, 3, 4)), 4).collect()
// => Array(4, 8, 12, 16)
multiply(spark.sparkContext.parallelize(Array(1.0, 2.0, 3.0, 4.0)), 4.0).collect()
// => Array(4.0, 8.0, 12.0, 16.0)
As you can see, the same multiply method works for integers and doubles; indeed, it will work on any of Scala’s numeric types as well as any type T for which you define an implicit instance of Numeric[T].
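As a sketch of that last point, the example below defines a Numeric instance for a hypothetical Mod7 type (integers modulo 7) and runs an Array-based analogue of multiply over Int, BigInt, and Mod7. The Array stand-in for RDD and all names here are assumptions for illustration; Numeric[Int] and Numeric[BigInt] come from the standard library.

```scala
import scala.reflect.ClassTag
import scala.util.Try

// Hypothetical type: integers modulo 7. Assumes v is already in 0..6.
final case class Mod7(v: Int)

object Mod7 {
  private def norm(n: Long): Mod7 = Mod7((((n % 7) + 7) % 7).toInt)

  // The typeclass instance that lets generic code treat Mod7 as numeric.
  implicit val mod7Numeric: Numeric[Mod7] = new Numeric[Mod7] {
    def plus(x: Mod7, y: Mod7): Mod7 = norm(x.v.toLong + y.v)
    def minus(x: Mod7, y: Mod7): Mod7 = norm(x.v.toLong - y.v)
    def times(x: Mod7, y: Mod7): Mod7 = norm(x.v.toLong * y.v)
    def negate(x: Mod7): Mod7 = norm(-x.v.toLong)
    def fromInt(x: Int): Mod7 = norm(x.toLong)
    // Abstract in Scala 2.13's Numeric; in 2.12 it is just an extra method.
    def parseString(s: String): Option[Mod7] = Try(s.trim.toLong).toOption.map(norm)
    def toInt(x: Mod7): Int = x.v
    def toLong(x: Mod7): Long = x.v.toLong
    def toFloat(x: Mod7): Float = x.v.toFloat
    def toDouble(x: Mod7): Double = x.v.toDouble
    // Orders by representative 0..6; Numeric requires an Ordering.
    def compare(x: Mod7, y: Mod7): Int = Integer.compare(x.v, y.v)
  }
}

object MultiplyDemo {
  // Array-based analogue of multiply.
  def multiply[T: ClassTag](a: Array[T], v: T)(implicit ev: Numeric[T]): Array[T] =
    a.map(x => ev.times(x, v))

  val intInput: Array[Int] = Array(1, 2, 3, 4)
  val ints: Array[Int] = multiply(intInput, 4)
  val bigs: Array[BigInt] = multiply(Array(BigInt(10), BigInt(20)), BigInt(3))
  val mods: Array[Mod7] = multiply(Array(Mod7(3), Mod7(5)), Mod7(4))
}
```

The generic method never mentions Mod7; supplying the Numeric[Mod7] instance is all it takes for the call to type-check.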
In conclusion, the RDD is invariant, but you can still do useful generic programming with it as long as you’re willing to use Scala’s implicit parameters.
Footnotes
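1. It is also possible to supply this function explicitly, which is useful when there are several possible options. We can use implicitly to simulate the Scala compiler’s implicit resolution, so we could invoke our function the way the Scala compiler does. Because the context bound T: ClassTag adds its evidence to the same implicit parameter list, that argument has to be supplied as well:
keyByUserId(xacts)(implicitly[ClassTag[Transaction]], implicitly[Transaction => HasUserId])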