4 votes

Suppose I use partitionBy to save some data to disk, e.g. partitioned by date, so my data looks like this:

/mydata/d=01-01-2018/part-00000
/mydata/d=01-01-2018/part-00001
...
/mydata/d=02-01-2018/part-00000
/mydata/d=02-01-2018/part-00001
...
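
For context, this layout is what partitionBy produces on write. A minimal sketch (the column name and path are the ones from this example; the sample rows are made up):

import spark.implicits._

// Writing with partitionBy("d") creates one subdirectory per distinct
// value of d, exactly like the listing above.
val df = Seq(("01-01-2018", 1), ("02-01-2018", 2)).toDF("d", "value")
df.write.partitionBy("d").parquet("/mydata")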

When I read the data back using the Hive config into a DataFrame, e.g.

val df = sparkSession.sql(s"select * from $database.$tableName")

I know that:

  • Filter queries on column d will be pushed down
  • No shuffles will occur if I try to partition by d (e.g. GROUP BY d); one way to verify both is shown in the sketch below
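
One way to verify both claims is to inspect the physical plan (a sketch, reusing the sparkSession, database and tableName values from this question):

import sparkSession.implicits._

val df = sparkSession.sql(s"select * from $database.$tableName")

// If pushdown works, the predicate on d appears as a PartitionFilter
// on the scan instead of a Filter applied after reading every file.
df.filter($"d" === "01-01-2018").explain()

// If the no-shuffle claim holds, no Exchange node should appear in this plan.
df.groupBy($"d").count().explain()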

BUT, suppose I don't know what the partition key is (some upstream job writes the data and follows no conventions). How can I get Spark to tell me which column is the partition key, in this case d? Similarly, how does this work if there are multiple partition levels (e.g. by month, then week, then day)?

Currently the best code we have is really ugly:

def getPartitionColumnsForHiveTable(databaseTableName: String)(implicit sparkSession: SparkSession): Set[String] = {
  // `desc` lists the data columns first, then a "# Partition Information"
  // section headed by a "# col_name" row, then the partition columns.
  val cols = sparkSession
    .sql(s"desc $databaseTableName")
    .select("col_name")
    .collect()
    .map(_.getAs[String](0))
    .dropWhile(r => !r.matches("# col_name"))
  // `cols` is empty for non-partitioned tables; otherwise drop the
  // "# col_name" header itself and keep the partition column names.
  if (cols.isEmpty) Set.empty[String] else cols.tail.toSet
}
are you sure no shuffles will occur in this case? I thought only bucketed hive tables have this behavior? - Raphael Roth
@RaphaelRoth I might be out of date, Spark changes the way files are read into partitions seemingly on every release (so what was once true isn't always true). - samthebest

3 Answers

4 votes

Assuming you don't have = and / in your partitioned column values, you can do:

import org.apache.spark.sql.AnalysisException
import spark.implicits._

// "show partitions" throws an AnalysisException for non-partitioned
// tables, so the sql() call itself belongs inside the try.
val partitionedCols: Set[String] = try {
  spark.sql("show partitions database.test_table")
    .map(_.getAs[String](0))
    .first.split('/').map(_.split("=")(0)).toSet
} catch {
  case e: AnalysisException => Set.empty[String]
}

You should get a Set[String] with the partitioned column names.
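
For reference, show partitions returns one row per partition, so for the question's table the rows would look like this (illustrative values):

d=01-01-2018
d=02-01-2018

and for multi-level partitioning something like month=01/week=02/day=03, which is why splitting on / and then on = recovers the column names.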

1 vote

You can use SQL statements to get this info: either show create table <tablename>, describe extended <tablename>, or show partitions <tablename>. The last one gives the simplest output to parse:

import spark.implicits._ // needed for .as[String]

val partitionCols = spark.sql("show partitions <tablename>").as[String].first.split('/').map(_.split("=").head)
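
Note that this throws an AnalysisException if the table is not partitioned, and a NoSuchElementException if it is partitioned but currently has no partitions. A sketch that folds both failure modes into an empty result (same placeholder table name as above):

import scala.util.Try

val partitionCols: Array[String] = Try {
  spark.sql("show partitions <tablename>").as[String].first.split('/').map(_.split("=").head)
}.getOrElse(Array.empty[String])
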
0 votes

Use the table metadata to get the partition column names as a comma-separated string. First check whether the table is partitioned; if it is, get the partition columns:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list}

val table = "default.country"

def isTablePartitioned(spark: SparkSession, table: String): Boolean = {
  import spark.implicits._
  val colDetails = spark.sql(s"describe extended $table")
    .select(collect_list(col("col_name"))).as[Array[String]].first
  colDetails.exists(_.contains("# Partition Information"))
}

def getPartitionColumns(spark: SparkSession, table: String): String = {
  import spark.implicits._
  // The partition columns sit between "# Partition Information" and
  // "# Detailed Table Information" in the describe extended output.
  val pat = """(?ms)^\s*#( Partition Information)(.+)(Detailed Table Information)\s*$""".r
  val colDetails = spark.sql(s"describe extended $table")
    .select(collect_list(col("col_name"))).as[Array[String]].first
  val flattened = colDetails.filter(_.trim.nonEmpty).mkString("\n")
  val partitionColumns = pat.findAllIn(flattened).matchData
    .collect { case pat(_, body, _) => body }
    .head
    .split("\n")
    .filterNot(_.contains("#"))
    .filter(_.nonEmpty)
  partitionColumns.mkString(",")
}

if (isTablePartitioned(spark, table))
  getPartitionColumns(spark, table)
else
  "--NO_PARTITIONS--"
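
For reference, the slice of the describe extended output that the regex pulls apart looks roughly like this in the col_name column (illustrative, for a table partitioned by a single column d):

# Partition Information
# col_name
d
# Detailed Table Information

so after dropping the lines containing #, only the partition column names remain.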

Note: the other two answers rely on show partitions returning at least one row, so they fail if the table is partitioned but currently empty.
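
If parsing SQL output feels fragile, the Catalog API exposes the same metadata directly. A minimal sketch, assuming Spark 2.x+, a SparkSession named spark, and the default.country table from above:

// listColumns returns one entry per column, with an isPartition flag
// marking the partition columns; this also works for empty tables.
val partitionCols: Set[String] = spark.catalog
  .listColumns("default.country")
  .collect()
  .filter(_.isPartition)
  .map(_.name)
  .toSet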