3
votes

I have case classes in a Scala 2.11 app, each with a method that relies on the names of the case class's fields, like this:

final case class Foo(
  val a: String,
  val b: String,
  val c: String
) {
  def partitionColumns: Seq[String] = Seq("b", "c")
}

I want a compile-time check that fails if any of the partitionColumns doesn't exist as a field on the case class; for example, something that would catch this:

final case class Bar(
  val a: String,
  val b: String,
  val c: String
) {
  def partitionColumns: Seq[String] = Seq("x", "y")
}

So far, I've encapsulated the partitioning behavior in a trait, which reduces the number of times/places that this can go wrong:

sealed trait PartitionedByBAndC {
  def b: String
  def c: String
  def partitionColumns: Seq[String] = Seq("b", "c")
}

final case class Foo(
  val a: String,
  val b: String,
  val c: String
) extends PartitionedByBAndC

But if the trait itself is written incorrectly, nothing checks it. For example, this code compiles fine:

sealed trait partitionedByBAndCIncorrect {
  def b: String
  def c: String
  def partitionColumns: Seq[String] = Seq("x", "y")
}

final case class Foo(
  val a: String,
  val b: String,
  val c: String
) extends partitionedByBAndCIncorrect

In Scala 2.13, I might be able to use productElementNames, but I'm on Scala 2.11 (and Spark 2.3). I'm not sure what to do without actually constructing an object out of the class/trait, which seems like a lot of overhead (considering there are many of these traits in the code).

Why would you have these fields? – Arnaud Claudel
For writing data out to a partitioned hive table. – Michael K
OK, and have you tried to use reflection? – Arnaud Claudel
I haven't, I'm not so familiar with it. I'll take a look. – Michael K

1 Answer

5
votes

There is a small macro library (scala-nameof) that can be used to do this:

final case class Foo(
  val a: String,
  val b: String,
  val c: String
) {
  import com.github.dwickern.macros.NameOf._

  // nameOf is a macro: it expands at compile time to the member's name as a String
  def partitionColumns: Seq[String] = Seq(nameOf(this.a), nameOf(this.b))
}

This won't compile if a referenced field is not part of the case class: nameOf(this.x), for example, is a compile error because Foo has no member x.
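
If adding a macro dependency is not an option, a weaker fallback is to check the names at runtime, for example once per class in a unit test. The sketch below is only an illustration, not a compile-time check: PartitionCheck and missingColumns are hypothetical names introduced here, and it assumes that each case class constructor parameter is backed by a JVM field of the same name, which is how scalac normally compiles case classes. It uses plain Java reflection, so it needs nothing beyond the standard library on Scala 2.11.

import scala.collection.immutable.Seq

// Hypothetical helper (not part of any library): reports partition columns
// that do not match a field declared on the given class.
object PartitionCheck {
  def missingColumns(cls: Class[_], columns: Seq[String]): Seq[String] = {
    // Assumption: case class parameters are compiled to fields with the same names.
    val fieldNames: Set[String] = cls.getDeclaredFields.map(_.getName).toSet
    columns.filterNot(fieldNames.contains)
  }
}

final case class Foo(a: String, b: String, c: String) {
  def partitionColumns: Seq[String] = Seq("b", "c")
}

final case class Bar(a: String, b: String, c: String) {
  def partitionColumns: Seq[String] = Seq("x", "y")
}

In a test, PartitionCheck.missingColumns(classOf[Foo], Seq("b", "c")) should come back empty, while the same call for Bar with Seq("x", "y") should report both names. This only catches the mistake when the test runs, so the nameOf approach is still preferable where a true compile-time failure is required.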