29
votes

I have a DataFrame with the schema

root
 |-- label: string (nullable = true)
 |-- features: struct (nullable = true)
 |    |-- feat1: string (nullable = true)
 |    |-- feat2: string (nullable = true)
 |    |-- feat3: string (nullable = true)

While I am able to filter the DataFrame using

  val data = rawData
     .filter( !(rawData("features.feat1") <=> "100") )

I am unable to drop the nested column using

  val data = rawData
       .drop("features.feat1")

Am I doing something wrong here? I also tried drop(rawData("features.feat1")), without success, though it does not make much sense to do so.

Thanks in advance,

Nikhil

9
What if you mapped it to a new dataframe instead? I don't think the DataFrame API allows you to drop a struct field within a struct column type. - Jeff L
Ohh. I will try that, but it seems pretty inconvenient if I have to map just to resolve a nested column name this way :(. - Nikhil J Joshi
You can always get all columns with the DataFrame's .columns() method, remove the unwanted column from the sequence and do select(myColumns:_*). Should be a bit shorter. - TheMP

9 Answers

29
votes

It is just a programming exercise but you can try something like this:

import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.types.{StructType, StructField}
import org.apache.spark.sql.{functions => f}
import scala.util.Try

case class DFWithDropFrom(df: DataFrame) {
  def getSourceField(source: String): Try[StructField] = {
    Try(df.schema.fields.filter(_.name == source).head)
  }

  def getType(sourceField: StructField): Try[StructType] = {
    Try(sourceField.dataType.asInstanceOf[StructType])
  }

  def genOutputCol(names: Array[String], source: String): Column = {
    f.struct(names.map(x => f.col(source).getItem(x).alias(x)): _*)
  }

  def dropFrom(source: String, toDrop: Array[String]): DataFrame = {
    getSourceField(source)
      .flatMap(getType)
      .map(_.fieldNames.diff(toDrop))
      .map(genOutputCol(_, source))
      .map(df.withColumn(source, _))
      .getOrElse(df)
  }
}

Example usage:

scala> case class features(feat1: String, feat2: String, feat3: String)
defined class features

scala> case class record(label: String, features: features)
defined class record

scala> val df = sc.parallelize(Seq(record("a_label",  features("f1", "f2", "f3")))).toDF
df: org.apache.spark.sql.DataFrame = [label: string, features: struct<feat1:string,feat2:string,feat3:string>]

scala> DFWithDropFrom(df).dropFrom("features", Array("feat1")).show
+-------+--------+
|  label|features|
+-------+--------+
|a_label| [f2,f3]|
+-------+--------+


scala> DFWithDropFrom(df).dropFrom("foobar", Array("feat1")).show
+-------+----------+
|  label|  features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+


scala> DFWithDropFrom(df).dropFrom("features", Array("foobar")).show
+-------+----------+
|  label|  features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+

Add an implicit conversion and you're good to go.

22
votes

This version allows you to remove nested columns at any level:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, DataType}

/**
  * Various Spark utilities and extensions of DataFrame
  */
object DataFrameUtils {

  private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else {
      colType match {
        case colType: StructType =>
          if (dropColName.startsWith(s"${fullColName}.")) {
            Some(struct(
              colType.fields
                .flatMap(f =>
                  dropSubColumn(col.getField(f.name), f.dataType, s"${fullColName}.${f.name}", dropColName) match {
                    case Some(x) => Some(x.alias(f.name))
                    case None => None
                  })
                : _*))
          } else {
            Some(col)
          }
        case other => Some(col)
      }
    }
  }

  protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
    df.schema.fields
      .flatMap(f => {
        if (colName.startsWith(s"${f.name}.")) {
          dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
            case Some(x) => Some((f.name, x))
            case None => None
          }
        } else {
          None
        }
      })
      .foldLeft(df.drop(colName)) {
        case (df, (colName, column)) => df.withColumn(colName, column)
      }
  }

  /**
    * Extended version of DataFrame that allows to operate on nested fields
    */
  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
      * Drops nested field from DataFrame
      *
      * @param colName Dot-separated nested field name
      */
    def dropNestedColumn(colName: String): DataFrame = {
      DataFrameUtils.dropColumn(df, colName)
    }
  }
}

Usage:

import DataFrameUtils._
df.dropNestedColumn("a.b.c.d")
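To see the recursion in isolation, here is a minimal Spark-free sketch in Python. It is an illustration under assumptions, not the answer's code: a struct is modeled as a nested dict and a leaf as a type-name string, and drop_path and drop_from_schema are hypothetical helper names. It mirrors the three cases of dropSubColumn: return nothing when the full name equals the target, rebuild the struct when the target lies underneath it, and otherwise pass the value through unchanged.

```python
from typing import Optional, Union

# Stand-in for a Spark schema: a dict is a struct, a string is a leaf type name.
Schema = Union[dict, str]


def drop_path(node: Schema, full_name: str, target: str) -> Optional[Schema]:
    """Return `node` without the field named by the dotted path `target`.

    `full_name` is the dotted path of `node` itself; None means "drop this node".
    """
    if full_name == target:
        return None  # this is the field to remove
    if isinstance(node, dict) and target.startswith(full_name + "."):
        # the target lives somewhere below this struct: rebuild it field by field
        rebuilt = {}
        for name, child in node.items():
            kept = drop_path(child, f"{full_name}.{name}", target)
            if kept is not None:
                rebuilt[name] = kept
        return rebuilt
    return node  # a leaf, or a struct that does not contain the target


def drop_from_schema(schema: dict, target: str) -> dict:
    """Apply drop_path to every top-level field of the schema."""
    result = {}
    for name, child in schema.items():
        kept = drop_path(child, name, target)
        if kept is not None:
            result[name] = kept
    return result


schema = {"a": {"b": {"c": "string", "d": "long"}, "e": "string"}, "label": "string"}
print(drop_from_schema(schema, "a.b.d"))
```

The trailing dot in the prefix check matters: without it, a target such as "a.bc" would wrongly be treated as living under "a.b".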
5
votes

Expanding on spektom's answer, with support for array types:

object DataFrameUtils {

  private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else if (dropColName.startsWith(s"$fullColName.")) {
      colType match {
        case colType: StructType =>
          Some(struct(
            colType.fields
              .flatMap(f =>
                dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                  case Some(x) => Some(x.alias(f.name))
                  case None => None
                })
              : _*))
        case colType: ArrayType =>
          colType.elementType match {
            case innerType: StructType =>
              Some(struct(innerType.fields
                .flatMap(f =>
                  dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                    case Some(x) => Some(x.alias(f.name))
                    case None => None
                  })
                : _*))
            case _ => Some(col) // arrays of non-struct elements are left unchanged (avoids a MatchError)
          }

        case other => Some(col)
      }
    } else {
      Some(col)
    }
  }

  protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
    df.schema.fields
      .flatMap(f => {
        if (colName.startsWith(s"${f.name}.")) {
          dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
            case Some(x) => Some((f.name, x))
            case None => None
          }
        } else {
          None
        }
      })
      .foldLeft(df.drop(colName)) {
        case (df, (colName, column)) => df.withColumn(colName, column)
      }
  }

  /**
    * Extended version of DataFrame that allows to operate on nested fields
    */
  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
      * Drops nested field from DataFrame
      *
      * @param colName Dot-separated nested field name
      */
    def dropNestedColumn(colName: String): DataFrame = {
      DataFrameUtils.dropColumn(df, colName)
    }
  }

}
5
votes

I will expand upon mmendez.semantic's answer here, accounting for the issues described in the sub-thread.

  def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else if (dropColName.startsWith(s"$fullColName.")) {
      colType match {
        case colType: StructType =>
          Some(struct(
            colType.fields
                .flatMap(f =>
                  dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                    case Some(x) => Some(x.alias(f.name))
                    case None => None
                  })
                : _*))
        case colType: ArrayType =>
          colType.elementType match {
            case innerType: StructType =>
              // we are potentially dropping a column from within a struct, that is itself inside an array
              // Spark has some very strange behavior in this case, which they insist is not a bug
              // see https://issues.apache.org/jira/browse/SPARK-31779 and associated comments
              // and also the thread here: https://stackoverflow.com/a/39943812/375670
              // this is a workaround for that behavior

              // first, get all struct fields
              val innerFields = innerType.fields
              // next, create a new type for all the struct fields EXCEPT the column that is to be dropped
              // we will need this later
              val preserveNamesStruct = ArrayType(StructType(
                innerFields.filterNot(f => s"$fullColName.${f.name}".equals(dropColName))
              ))
              // next, apply dropSubColumn recursively to build up the new values after dropping the column
              val filteredInnerFields = innerFields.flatMap(f =>
                dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                    case Some(x) => Some(x.alias(f.name))
                    case None => None
                }
              )
              // finally, use arrays_zip to unwrap the arrays that were introduced by building up the new, filtered
              // struct in this way (see comments in SPARK-31779), and then cast to the StructType we created earlier
              // to get the original names back
              Some(arrays_zip(filteredInnerFields:_*).cast(preserveNamesStruct))
            case _ => Some(col) // arrays of non-struct elements are left unchanged (avoids a MatchError)
          }

        case _ => Some(col)
      }
    } else {
      Some(col)
    }
  }

  def dropColumn(df: DataFrame, colName: String): DataFrame = {
    df.schema.fields.flatMap(f => {
      if (colName.startsWith(s"${f.name}.")) {
        dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
          case Some(x) => Some((f.name, x))
          case None => None
        }
      } else {
        None
      }
    }).foldLeft(df.drop(colName)) {
      case (df, (colName, column)) => df.withColumn(colName, column)
    }
  }

Usage in spark-shell:

// if defining the functions above in your spark-shell session, you first need imports
import org.apache.spark.sql._
import org.apache.spark.sql.functions._ // struct, arrays_zip, col
import org.apache.spark.sql.types._

// now you can paste the function definitions

// create a deeply nested and complex JSON structure    
val jsonData = """{
      "foo": "bar",
      "top": {
        "child1": 5,
        "child2": [
          {
            "child2First": "one",
            "child2Second": 2,
            "child2Third": -19.51
          }
        ],
        "child3": ["foo", "bar", "baz"],
        "child4": [
          {
            "child2First": "two",
            "child2Second": 3,
            "child2Third": 16.78
          }
        ]
      }
    }"""

// read it into a DataFrame
val df = spark.read.option("multiline", "true").json(Seq(jsonData).toDS())

// remove a sub-column
val modifiedDf = dropColumn(df, "top.child2.child2First")

modifiedDf.printSchema
root
 |-- foo: string (nullable = true)
 |-- top: struct (nullable = false)
 |    |-- child1: long (nullable = true)
 |    |-- child2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- child2Second: long (nullable = true)
 |    |    |    |-- child2Third: double (nullable = true)
 |    |-- child3: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- child4: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- child2First: string (nullable = true)
 |    |    |    |-- child2Second: long (nullable = true)
 |    |    |    |-- child2Third: double (nullable = true)


modifiedDf.show(truncate=false)
+---+------------------------------------------------------+
|foo|top                                                   |
+---+------------------------------------------------------+
|bar|[5, [[2, -19.51]], [foo, bar, baz], [[two, 3, 16.78]]]|
+---+------------------------------------------------------+
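Why arrays_zip is needed may be easier to see outside Spark. In the plain-Python sketch below, a list of dicts stands in for an array of structs; get_field and arrays_zip are simplified stand-ins written for this illustration, not the Spark functions, and the second array element is made up so that the reshaping is visible. Extracting a field from an array of structs yields one array per field, so a struct built from those extractions is a struct of arrays; zipping the arrays element-wise restores an array of structs.

```python
def get_field(array_of_structs, name):
    """Mimics col.getField(name) on an array of structs: one value per element."""
    return [element[name] for element in array_of_structs]


def arrays_zip(named_arrays):
    """Simplified stand-in: zip equally long arrays into a list of structs."""
    names = list(named_arrays)
    return [dict(zip(names, values)) for values in zip(*named_arrays.values())]


child2 = [
    {"child2First": "one", "child2Second": 2, "child2Third": -19.51},
    {"child2First": "uno", "child2Second": 3, "child2Third": 0.5},  # illustrative extra element
]

# Keep every field except the one being dropped.
kept = [name for name in child2[0] if name != "child2First"]

# Naive rebuild: a struct whose fields are arrays (the shape has changed).
struct_of_arrays = {name: get_field(child2, name) for name in kept}

# Zipping turns it back into an array of structs (the original shape).
array_of_structs = arrays_zip(struct_of_arrays)

print(struct_of_arrays)
print(array_of_structs)
```

In the Scala code above, the cast to preserveNamesStruct plays the role of the names list here: it puts the original field names back on the zipped elements.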
3
votes

Following spektom's Scala code snippet, I've created similar code in Java. Since Java 8 doesn't have foldLeft, I used forEachOrdered. This code is suitable for Spark 2.x (I'm using 2.1). I also noticed that dropping a column and then adding it back with withColumn under the same name doesn't work, so I'm just replacing the column, and that seems to work.

Code is not fully tested, hope it works :-)

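import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.struct;

public class DataFrameUtils {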

public static Dataset<Row> dropNestedColumn(Dataset<Row> dataFrame, String columnName) {
    final DataFrameFolder dataFrameFolder = new DataFrameFolder(dataFrame);
    Arrays.stream(dataFrame.schema().fields())
        .flatMap( f -> {
           if (columnName.startsWith(f.name() + ".")) {
               final Optional<Column> column = dropSubColumn(col(f.name()), f.dataType(), f.name(), columnName);
               if (column.isPresent()) {
                   return Stream.of(new Tuple2<>(f.name(), column));
               } else {
                   return Stream.empty();
               }
           } else {
               return Stream.empty();
           }
        }).forEachOrdered(colTuple -> dataFrameFolder.accept(colTuple));

    return dataFrameFolder.getDF();
}

private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
    Optional<Column> column = Optional.empty();
    if (!fullColumnName.equals(dropColumnName)) {
        if (colType instanceof StructType) {
            if (dropColumnName.startsWith(fullColumnName + ".")) {
                column = Optional.of(struct(getColumns(col, (StructType)colType, fullColumnName, dropColumnName)));
            } else {
                // a struct that does not contain the field to drop is kept unchanged
                column = Optional.of(col);
            }
        } else {
            column = Optional.of(col);
        }
    }

    return column;
}

private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
    return Arrays.stream(colType.fields())
        .flatMap(f -> {
                    final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
                            fullColumnName + "." + f.name(), dropColumnName);
                    if (column.isPresent()) {
                        return Stream.of(column.get().alias(f.name()));
                    } else {
                        return Stream.empty();
                    }
                }
        ).toArray(Column[]::new);

}

private static class DataFrameFolder implements Consumer<Tuple2<String, Optional<Column>>> {
    private Dataset<Row> df;

    public DataFrameFolder(Dataset<Row> df) {
        this.df = df;
    }

    public Dataset<Row> getDF() {
        return df;
    }

    @Override
    public void accept(Tuple2<String, Optional<Column>> colTuple) {
        if (!colTuple._2().isPresent()) {
            df = df.drop(colTuple._1());
        } else {
            df = df.withColumn(colTuple._1(), colTuple._2().get());
        }
    }
}

}

Usage example:

private class Pojo {
    private String str;
    private Integer number;
    private List<String> strList;
    private Pojo2 pojo2;

    public String getStr() {
        return str;
    }

    public Integer getNumber() {
        return number;
    }

    public List<String> getStrList() {
        return strList;
    }

    public Pojo2 getPojo2() {
        return pojo2;
    }

}

private class Pojo2 {
    private String str;
    private Integer number;
    private List<String> strList;

    public String getStr() {
        return str;
    }

    public Integer getNumber() {
        return number;
    }

    public List<String> getStrList() {
        return strList;
    }

}

SQLContext context = new SQLContext(new SparkContext("local[1]", "test"));
Dataset<Row> df = context.createDataFrame(Collections.emptyList(), Pojo.class);
Dataset<Row> dfRes = DataFrameUtils.dropNestedColumn(df, "pojo2.str");

Original struct:

root
 |-- number: integer (nullable = true)
 |-- pojo2: struct (nullable = true)
 |    |-- number: integer (nullable = true)
 |    |-- str: string (nullable = true)
 |    |-- strList: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- str: string (nullable = true)
 |-- strList: array (nullable = true)
 |    |-- element: string (containsNull = true)

After drop:

root
 |-- number: integer (nullable = true)
 |-- pojo2: struct (nullable = false)
 |    |-- number: integer (nullable = true)
 |    |-- strList: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- str: string (nullable = true)
 |-- strList: array (nullable = true)
 |    |-- element: string (containsNull = true)
3
votes

Another (PySpark) way would be to drop the features.feat1 column by creating features again:

from pyspark.sql.functions import col, arrays_zip

display(df
        .withColumn("features", arrays_zip("features.feat2", "features.feat3"))
        .withColumn("features", col("features").cast(schema))
)

Where schema is the new schema (excluding features.feat1).

from pyspark.sql.types import StructType, StructField, StringType

schema = StructType(
    [
      StructField('feat2', StringType(), True), 
      StructField('feat3', StringType(), True), 
    ]
  )
2
votes

PySpark implementation

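from typing import List

import pyspark.sql.functions as sf
from pyspark.sql import Column, DataFrame
from pyspark.sql.types import StructType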

def _drop_nested_field(
    schema: StructType,
    field_to_drop: str,
    parents: List[str] = None,
) -> Column:
    parents = list() if parents is None else parents
    src_col = lambda field_names: sf.col('.'.join(f'`{c}`' for c in field_names))

    if '.' in field_to_drop:
        root, subfield = field_to_drop.split('.', maxsplit=1)
        field_to_drop_from = next(f for f in schema.fields if f.name == root)

        return sf.struct(
            *[src_col(parents + [f.name]) for f in schema.fields if f.name != root],
            _drop_nested_field(
                schema=field_to_drop_from.dataType,
                field_to_drop=subfield,
                parents=parents + [root]
            ).alias(root)
        )

    else:
        # select all columns except the one to drop
        return sf.struct(
            *[src_col(parents + [f.name]) for f in schema.fields if f.name != field_to_drop],
        )


def drop_nested_field(
    df: DataFrame,
    field_to_drop: str,
) -> DataFrame:
    if '.' in field_to_drop:
        root, subfield = field_to_drop.split('.', maxsplit=1)
        field_to_drop_from = next(f for f in df.schema.fields if f.name == root)

        return df.withColumn(root, _drop_nested_field(
            schema=field_to_drop_from.dataType,
            field_to_drop=subfield,
            parents=[root]
        ))
    else:
        return df.drop(field_to_drop)


df = drop_nested_field(df, 'a.b.c.d')
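The two string manipulations these functions rely on can be tried without Spark. The sketch below is plain Python under assumptions: quoted_path is a hypothetical name that mirrors the src_col lambda without the sf.col call, and the loop replays the repeated split('.', maxsplit=1) that the recursion performs one level at a time. The backticks are presumably there so that each name is resolved as a single identifier even if it contains special characters.

```python
from typing import List


def quoted_path(field_names: List[str]) -> str:
    """Join field names into a dotted path, backtick-quoting every part."""
    return ".".join(f"`{name}`" for name in field_names)


parents: List[str] = []
remaining = "a.b.c.d"
while "." in remaining:
    # peel off the first component, exactly one level per iteration
    root, remaining = remaining.split(".", maxsplit=1)
    parents.append(root)

struct_path = quoted_path(parents)
print(struct_path)  # path of the innermost struct that gets rebuilt
print(remaining)    # name of the field that is left out of it
```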
0
votes

Adding a Java version of the solution.

Utility class (pass your dataset and the nested column to be dropped to the dropNestedColumn function).

(There are a few bugs in Lior Chaga's answer; I corrected them while trying to use it.)

public class NestedColumnActions {

// separator between the levels of a nested column name
private static final String DOT = ".";
/*
dataset : dataset in which we want to drop columns
columnName : nested column that needs to be deleted
*/
public static Dataset<?> dropNestedColumn(Dataset<?> dataset, String columnName) {

    //Special case of top level column deletion
    if(!columnName.contains("."))
        return dataset.drop(columnName);

    final DataSetModifier dataFrameFolder = new DataSetModifier(dataset);
    Arrays.stream(dataset.schema().fields())
            .flatMap(f -> {
                //If the column name to be deleted starts with current top level column
                if (columnName.startsWith(f.name() + DOT)) {
                    //Get new column structure under f , expected after deleting the required column
                    final Optional<Column> column = dropSubColumn(functions.col(f.name()), f.dataType(), f.name(), columnName);
                    if (column.isPresent()) {
                        return Stream.of(new Tuple2<>(f.name(), column));
                    } else {
                        return Stream.empty();
                    }
                } else {
                    return Stream.empty();
                }
            })
            //Call accept function with Tuples of (top level column name, new column structure under it)
            .forEach(colTuple -> dataFrameFolder.accept(colTuple));

    return dataFrameFolder.getDataset();
}

private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
    Optional<Column> column = Optional.empty();
    if (!fullColumnName.equals(dropColumnName)) {
        if (colType instanceof StructType) {
            if (dropColumnName.startsWith(fullColumnName + DOT)) {
                column = Optional.of(functions.struct(getColumns(col, (StructType) colType, fullColumnName, dropColumnName)));
            }
            else {
                column = Optional.of(col);
            }
        } else {
            column = Optional.of(col);
        }
    }

    return column;
}

private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
    return Arrays.stream(colType.fields())
            .flatMap(f -> {
                        final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
                                fullColumnName + "." + f.name(), dropColumnName);
                        if (column.isPresent()) {
                            return Stream.of(column.get().alias(f.name()));
                        } else {
                            return Stream.empty();
                        }
                    }
            ).toArray(Column[]::new);

}

private static class DataSetModifier implements Consumer<Tuple2<String, Optional<Column>>> {
    private Dataset<?> df;

    public DataSetModifier(Dataset<?> df) {
        this.df = df;
    }

    public Dataset<?> getDataset() {
        return df;
    }

    /*
    colTuple[0]:top level column name
    colTuple[1]:new column structure under it
   */
    @Override
    public void accept(Tuple2<String, Optional<Column>> colTuple) {
        if (!colTuple._2().isPresent()) {
            df = df.drop(colTuple._1());
        } else {
            df = df.withColumn(colTuple._1(), colTuple._2().get());
        }
    }
}

}

0
votes

The Make Structs Easy* library makes it easy to perform operations like adding, dropping, and renaming fields inside nested data structures. The library is available in both Scala and Python.

Assuming you have the following data:

import org.apache.spark.sql.functions._

case class Features(feat1: String, feat2: String, feat3: String)
case class Record(features: Features, arrayOfFeatures: Seq[Features])

val df = Seq(
   Record(Features("hello", "world", "!"), Seq(Features("red", "orange", "yellow"), Features("green", "blue", "indigo")))
).toDF

df.printSchema

// root
//  |-- features: struct (nullable = true)
//  |    |-- feat1: string (nullable = true)
//  |    |-- feat2: string (nullable = true)
//  |    |-- feat3: string (nullable = true)
//  |-- arrayOfFeatures: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- feat1: string (nullable = true)
//  |    |    |-- feat2: string (nullable = true)
//  |    |    |-- feat3: string (nullable = true)

df.show(false)

// +-----------------+----------------------------------------------+
// |features         |arrayOfFeatures                               |
// +-----------------+----------------------------------------------+
// |[hello, world, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +-----------------+----------------------------------------------+

Then dropping feat2 from features is as simple as:

import com.github.fqaiser94.mse.methods._

// drop feat2 from features
df.withColumn("features", $"features".dropFields("feat2")).show(false)

// +----------+----------------------------------------------+
// |features  |arrayOfFeatures                               |
// +----------+----------------------------------------------+
// |[hello, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +----------+----------------------------------------------+

I noticed there were a lot of follow-up comments on other solutions asking if there's a way to drop a Column nested inside a struct nested inside of an array. This can be done by combining the functions provided by the Make Structs Easy library with the functions provided by spark-hofs library, as follows:

import za.co.absa.spark.hofs._

// drop feat2 in each element of arrayOfFeatures
df.withColumn("arrayOfFeatures", transform($"arrayOfFeatures", features => features.dropFields("feat2"))).show(false)

// +-----------------+--------------------------------+
// |features         |arrayOfFeatures                 |
// +-----------------+--------------------------------+
// |[hello, world, !]|[[red, yellow], [green, indigo]]|
// +-----------------+--------------------------------+
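Conceptually, the combination works element by element. The plain-Python model below is only an illustration of that idea: drop_fields and transform are stand-ins written for this sketch (dicts play the role of structs, lists the role of arrays), not the APIs of either library. It reuses the arrayOfFeatures data from above.

```python
def drop_fields(struct: dict, *names: str) -> dict:
    """Stand-in for dropFields: copy the struct without the named fields."""
    return {key: value for key, value in struct.items() if key not in names}


def transform(array: list, fn) -> list:
    """Stand-in for transform: apply fn to every element of the array."""
    return [fn(element) for element in array]


array_of_features = [
    {"feat1": "red", "feat2": "orange", "feat3": "yellow"},
    {"feat1": "green", "feat2": "blue", "feat3": "indigo"},
]

result = transform(array_of_features, lambda features: drop_fields(features, "feat2"))
print([list(element.values()) for element in result])
# prints [['red', 'yellow'], ['green', 'indigo']]
```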

*Full disclosure: I am the author of the Make Structs Easy library that is referenced in this answer.