Natural join for data frames in Spark

Tags: spark, sql, data frames

Published April 8, 2015

Natural join is a useful special case of the relational join operation (and is extremely common when denormalizing data pulled in from a relational database). Spark’s DataFrame API provides an expressive way to specify arbitrary joins, but it would be nice to have some machinery to make the simple case of natural join as easy as possible. Here’s what a natural join needs to do:

  1. For relations R and S, identify the columns they have in common, say c1 and c2;
  2. join R and S on the condition that R.c1 == S.c1 and R.c2 == S.c2; and
  3. project away the duplicated columns.
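
As a Spark-independent illustration of these three steps, here is a minimal sketch over plain Scala collections. The `Relation` case class and the `NaturalJoinSketch` object are hypothetical names introduced only for this example. It assumes every row has a value for each of its relation's columns, and it returns no rows when the relations share no columns.

```scala
// Illustrative model only: a relation is a column list plus rows keyed by column name.
// Neither Relation nor NaturalJoinSketch is part of Spark.
final case class Relation(columns: Seq[String], rows: Seq[Map[String, Any]])

object NaturalJoinSketch {
  def natjoin(r: Relation, s: Relation): Relation = {
    // Step 1: columns the two relations share, in the order they appear in r.
    val common = r.columns.filter(c => s.columns.contains(c))
    val rOnly = r.columns.filterNot(c => common.contains(c))
    val sOnly = s.columns.filterNot(c => common.contains(c))

    val rows =
      if (common.isEmpty) Seq.empty[Map[String, Any]] // choice made for this sketch
      else for {
        rr <- r.rows
        sr <- s.rows
        // Step 2: keep only pairs of rows that agree on every common column.
        if common.forall(c => rr(c) == sr(c))
        // Step 3: the common columns come from r once; add only s's own columns.
      } yield rr ++ sOnly.map(c => c -> sr(c))

    Relation(common ++ rOnly ++ sOnly, rows)
  }
}
```

Joining a relation with columns (c1, c2, x) to one with columns (c2, c1, y) in this sketch gives the columns c1, c2, x, y, with one output row for each pair of input rows that agree on both c1 and c2.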

So, in Spark, a natural join on the common columns c1 and c2 would look like this:

/* r and s are DataFrames, declared elsewhere */

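// Column equality in the DataFrame API is ===; plain == would compare the Column objects themselves
val joined = r.join(s, r("c1") === s("c1") && r("c2") === s("c2"))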
val common = Set("c1", "c2")
val outputColumns = Seq(r("c1"), r("c2")) ++ 
                    r.columns.collect { case c if !common.contains(c) => r(c) } ++
                    s.columns.collect { case c if !common.contains(c) => s(c) }
val projected = joined.select(outputColumns : _*)

We can generalize this as follows (note that joining two frames with no columns in common will produce an empty frame):

import org.apache.spark.sql.DataFrame
import scala.language.implicitConversions

trait NaturalJoining {
  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.types._  
  
  /**
   * Performs a natural join of two data frames.
   *
   * The frames are joined by equality on all of the columns they have in common.
   * The resulting frame has the common columns (in the order they appeared in <code>left</code>), 
   * followed by the columns that only exist in <code>left</code>, followed by the columns that 
   * only exist in <code>right</code>.
   */
  def natjoin(left: DataFrame, right: DataFrame): DataFrame = {
    val leftCols = left.columns
    val rightCols = right.columns

    val commonCols = leftCols.toSet intersect rightCols.toSet
    
    if(commonCols.isEmpty)
      // no shared columns: an empty frame that still has the columns of both inputs
      left.limit(0).join(right.limit(0))
    else
      left
        .join(right, commonCols.map {col => left(col) === right(col) }.reduce(_ && _))
        .select(leftCols.collect { case c if commonCols.contains(c) => left(c) } ++ 
                leftCols.collect { case c if !commonCols.contains(c) => left(c) } ++ 
                rightCols.collect { case c if !commonCols.contains(c) => right(c) } : _*)
  }
}

Furthermore, we can make this operation available to any DataFrame via implicit conversions:

case class DFWithNatJoin(df: DataFrame) extends NaturalJoining {
  def natjoin(other: DataFrame): DataFrame = super.natjoin(df, other)
}

/** 
 * Module for natural join functionality.  Import <code>NaturalJoin._</code> for static access 
 * to the <code>natjoin</code> method, or import <code>NaturalJoin.implicits._</code> to enrich 
 * Spark DataFrames with a <code>natjoin</code> member method. 
 */
object NaturalJoin extends NaturalJoining {  
  object implicits {
    implicit def dfWithNatJoin(df: DataFrame): DFWithNatJoin = DFWithNatJoin(df)
  }
}
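
A usage sketch may help here. It assumes the definitions above are compiled in the same (default) package, and it assumes Spark 2.0 or later so that `SparkSession` is available (older versions would go through a `SQLContext` instead). `NatJoinExample`, the column names, and the sample rows are made up for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import NaturalJoin.implicits._ // brings the DataFrame => DFWithNatJoin conversion into scope

object NatJoinExample {
  /** Builds two small frames that share c1 and c2, then natural-joins them. */
  def joinExample(spark: SparkSession): DataFrame = {
    import spark.implicits._ // enables toDF on local sequences

    val r = Seq((1, "a", 10.0), (2, "b", 20.0)).toDF("c1", "c2", "x")
    val s = Seq((1, "a", true), (3, "c", false)).toDF("c1", "c2", "y")

    // Member-method form via the implicit conversion;
    // NaturalJoin.natjoin(r, s) is the equivalent static form.
    r.natjoin(s)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("natjoin-example")
      .getOrCreate()

    // Result columns are c1, c2, x, y; only the row with c1 = 1 and c2 = "a" matches.
    joinExample(spark).show()

    spark.stop()
  }
}
```

The join logic lives in the `NaturalJoining` trait, so both the member-method form and the static form run the same code.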

If you’re interested in using this code in your own projects, simply add the Silex library to your project and import com.redhat.et.silex.frame._. (You can also get Silex via bintray.)