This isn't really a Spark question so much as a Scala type question, but what I'm doing might be of interest to Spark fans, so I'm keeping 'Spark' in my framing of the question, which is:
I want to recursively transform a Spark SQL schema of StructType, which contains a list whose elements may be either StructTypes or StructFields. The result of the transform should be a version of the original schema that disallows nulls in any field. Unfortunately, StructType and StructField don't extend from a common marker trait. This led to my initial implementation, where the method accepted 'Any' and explicitly cast the result back to StructType.
Initial Implementation
object SchemaTools extends App {
  import org.apache.spark.sql.Encoders
  import org.apache.spark.sql.types._

  def noNullSchema(schema: StructType): StructType = {
    def go(element: Any): Product = element match {
      case x: StructField => x.copy(nullable = false)
      case x: StructType  => StructType(x.fields.map(_.copy(nullable = false)))
      case bad            => sys.error(s"element of unexpected type: $bad")
    }
    go(schema).asInstanceOf[StructType]
  }

  type Rec = (String, Seq[(Int, Int, String)])

  val schema: StructType = Encoders.product[Rec].schema
  System.out.println("pr:" + schema.prettyJson)
  System.out.println("pr:" + noNullSchema(schema).prettyJson)
}
UPDATE
I am accepting Tim's answer, since he kindly pointed out my dumb mistake: I wasn't recursing down into the nested structure. My bad: I misunderstood what goes inside a StructType (it is always an array of StructField, never an array of either StructField or StructType). The fields within that array may themselves have a data type of StructType, which is what drives the need for recursion. With this cleared up I have no issues related to types.

Below is a revised version of the above "proof of concept" de-nullifier. It works on my example input and illustrates the general approach I would take if I needed a full-on solution (instead of just implementing for learning's sake). The code is definitely not production ready and will fail on more complex inputs, but it illustrates a possible approach.
Note: one other thing I learned about nulls and schemas that is very important to keep in mind: even if one correctly implements a schema "de-nuller", Spark will not enforce nullability checks during parsing. This is discussed in more detail here: Nullability in Spark sql schemas is advisory by default. What is best way to strictly enforce it?
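To illustrate (a minimal sketch of my own, not from the linked question; the local SparkSession setup and the inline JSON record are assumptions): even when a field is declared nullable = false, Spark still parses a null into it without raising an error.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// "name" is declared non-nullable, yet the null record below is parsed anyway.
val strict = StructType(Seq(StructField("name", StringType, nullable = false)))
val df = spark.read.schema(strict).json(Seq("""{"name": null}""").toDS())
df.show() // prints a row containing null; no error is raised during parsing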
Proof of Concept ... No Longer Has Issues With Types
object SchemaTools extends App {
  import org.apache.spark.sql.Encoders
  import org.apache.spark.sql.types._

  // Handles the one nesting pattern in the example below: a struct inside an array.
  def noNullSchema(field: StructField): StructField = {
    field.dataType match {
      case ArrayType(StructType(fields), containsNull) =>
        StructField(
          field.name,
          ArrayType(noNullSchema(StructType(fields)), containsNull),
          nullable = false,
          field.metadata)
      case _ => field.copy(nullable = false)
    }
  }

  def noNullSchema(schema: StructType): StructType =
    StructType(
      schema.fields.map { f =>
        System.out.println("f:" + f)
        noNullSchema(f)
      }
    )

  type Rec = (String, Seq[(Int, String, String)])

  val schema: StructType = Encoders.product[Rec].schema
  System.out.println("pr:" + schema.prettyJson)
  System.out.println("pr:" + noNullSchema(schema).prettyJson)
}
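As a point of comparison, here is how the recursion might be generalized beyond the toy above. This is my own sketch, not part of the accepted answer, and the names deNull and noNullSchemaFull are mine. It walks any DataType, so nested structs at any depth, array element nullability, and map value nullability are all cleared.

import org.apache.spark.sql.types._

// Recursively rewrite a DataType so that nothing inside it admits nulls.
def deNull(dataType: DataType): DataType = dataType match {
  case StructType(fields) =>
    StructType(fields.map(f => f.copy(dataType = deNull(f.dataType), nullable = false)))
  case ArrayType(elementType, _) =>
    ArrayType(deNull(elementType), containsNull = false)
  case MapType(keyType, valueType, _) =>
    MapType(deNull(keyType), deNull(valueType), valueContainsNull = false)
  case other => other // primitive types carry no nullability of their own
}

def noNullSchemaFull(schema: StructType): StructType =
  deNull(schema).asInstanceOf[StructType]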
noNullSchema means that element is always StructType, so there is no need to make go generic. - Tim
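(A minimal illustration of Tim's point, not from the original thread; noNullSchemaSimple is a name of my own: with the parameter typed as StructType, the inner helper and the asInstanceOf cast are unnecessary.)

import org.apache.spark.sql.types._

// The typed equivalent of the original go/asInstanceOf pair.
def noNullSchemaSimple(schema: StructType): StructType =
  StructType(schema.fields.map(_.copy(nullable = false)))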