2
votes

The question linked below has solutions for Scala and PySpark, but the solution given there does not produce consecutive index values.

Spark Dataframe :How to add a index Column : Aka Distributed Data Index

I have an existing Dataset in Apache Spark and I want to select some rows from it based on their index. I am planning to add an index column that contains unique values starting from 1, and then fetch rows based on the values of that column. I found the method below, which adds an index using order by:

df.withColumn("index", functions.row_number().over(Window.orderBy("a column")));

I do not want to use order by. I need the index to follow the order in which the rows are already present in the Dataset. Any help?

2
I've rewritten my answer in Java. Let me know if it works. Also, I'm not convinced this is a duplicate, since the answer in Java is quite different and much more verbose. - Oli

2 Answers

1
votes

From what I gather, you are trying to add an index with consecutive values to a DataFrame. Unfortunately, Spark has no built-in function that does that. The closest built-in only adds an increasing index whose values are not necessarily consecutive: df.withColumn("index", functions.monotonically_increasing_id()) (monotonicallyIncreasingId() is the older, deprecated name for the same function).
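
To see why those ids are increasing but not consecutive, here is a minimal plain-Java sketch (no Spark dependency) of the layout described in the Spark documentation: the partition ID goes in the upper 31 bits and the record number within the partition in the lower 33 bits. The class name, the helper monotonicId, and the assumed split of 5 rows over 4 partitions are hypothetical and only for illustration.

```java
// Plain-Java illustration of the documented id layout; this is not Spark code.
class MonotonicIdLayout {
    // Hypothetical helper: partition ID in the upper bits, row number in the lower 33 bits.
    static long monotonicId(int partitionId, long rowInPartition) {
        return ((long) partitionId << 33) + rowInPartition;
    }

    public static void main(String[] args) {
        // Assumption: 5 rows spread over 4 partitions as [0], [1], [2], [3, 4].
        int[] partitionOfRow = {0, 1, 2, 3, 3};
        long[] rowInPartition = {0, 0, 0, 0, 1};
        for (int i = 0; i < partitionOfRow.length; i++) {
            System.out.println(i + " -> " + monotonicId(partitionOfRow[i], rowInPartition[i]));
        }
    }
}
```

Under that assumed partitioning, the first row of partition 1 gets 1 << 33 = 8589934592, so every partition boundary introduces a large gap in the ids.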

Nonetheless, the RDD API has a zipWithIndex function that does exactly what you need. We can thus define a function that transforms the DataFrame into an RDD, adds the index, and transforms the result back into a DataFrame.
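
As a rough mental model of how zipWithIndex can hand out consecutive values across partitions, the plain-Java sketch below (no Spark dependency) counts the rows of each partition, turns the counts into start offsets, and numbers the rows by partition order and then by position within the partition. The class, the method names, and the nested lists standing in for partitions are hypothetical; the real RDD.zipWithIndex is 0-based, and the +1 here mirrors the 1-based index asked for in the question.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java simulation of the idea; this is not Spark's implementation.
class ZipWithIndexSketch {
    // Start offset of a partition = number of rows in all earlier partitions.
    static long[] startOffsets(List<List<String>> partitions) {
        long[] offsets = new long[partitions.size()];
        long running = 0;
        for (int p = 0; p < partitions.size(); p++) {
            offsets[p] = running;
            running += partitions.get(p).size();
        }
        return offsets;
    }

    // Pairs every element with a consecutive 1-based index, keeping the existing order.
    static List<String> zipWithIndex(List<List<String>> partitions) {
        long[] offsets = startOffsets(partitions);
        List<String> out = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            List<String> rows = partitions.get(p);
            for (int i = 0; i < rows.size(); i++) {
                out.add(rows.get(i) + ":" + (offsets[p] + i + 1));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Assumption: 5 rows spread over 4 partitions.
        List<List<String>> partitions = Arrays.asList(
                Arrays.asList("a"), Arrays.asList("b"),
                Arrays.asList("c"), Arrays.asList("d", "e"));
        System.out.println(zipWithIndex(partitions));
    }
}
```

Because the offsets depend on the sizes of all earlier partitions, Spark has to run a job to count rows when the RDD has more than one partition, which is the extra cost compared with the built-in id function.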

I'm not an expert in Spark with Java (Scala is much more compact), so it might be possible to do better. Here is how I would do it.

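public static Dataset<Row> zipWithIndex(Dataset<Row> df, String name) {
    // zipWithIndex pairs each row with a 0-based Long index, in the RDD's existing order.
    JavaRDD<Row> rdd = df.javaRDD().zipWithIndex().map(t -> {
        Row r = t._1();
        Long index = t._2() + 1; // shift to a 1-based index
        ArrayList<Object> list = new ArrayList<>();
        // Copy the existing column values, then append the index as the last column.
        for (int i = 0; i < r.size(); i++) {
            list.add(r.get(i));
        }
        list.add(index);
        // RowFactory.create takes varargs, so pass an array rather than the list itself.
        return RowFactory.create(list.toArray());
    });
    // Metadata.empty() replaces the original null, which is not a valid metadata value.
    StructType newSchema = df.schema()
            .add(new StructField(name, DataTypes.LongType, true, Metadata.empty()));
    return df.sparkSession().createDataFrame(rdd, newSchema);
}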

And here is how you would use it. Notice how the output of the built-in Spark function differs from that of our approach.

Dataset<Row> df = spark.range(5)
    .withColumn("index1", functions.monotonically_increasing_id());
Dataset<Row> result = zipWithIndex(df, "good_index");
// df
+---+-----------+
| id|     index1|
+---+-----------+
|  0|          0|
|  1| 8589934592|
|  2|17179869184|
|  3|25769803776|
|  4|25769803777|
+---+-----------+

// result
+---+-----------+----------+
| id|     index1|good_index|
+---+-----------+----------+
|  0|          0|         1|
|  1| 8589934592|         2|
|  2|17179869184|         3|
|  3|25769803776|         4|
|  4|25769803777|         5|
+---+-----------+----------+
0
votes

The above answer worked for me with some adjustments. Below is a working IntelliJ scratch file. I'm on Spark 2.3.0:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

class Scratch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                    .builder()
                    .appName("_LOCAL")
                    .master("local")
                    .getOrCreate();
        Dataset<Row> df = spark.range(5)
                .withColumn("index1", functions.monotonically_increasing_id());
        Dataset<Row> result = zipWithIndex(df, "good_index");
        result.show();
    }
    public static Dataset<Row> zipWithIndex(Dataset<Row> df, String name) {
        JavaRDD<Row> rdd = df.javaRDD().zipWithIndex().map(t -> {
            Row r = t._1;
            Long index = t._2 + 1;
            ArrayList<Object> list = new ArrayList<>();
            // Copy the existing column values; nulls are valid column values and are kept as-is.
            scala.collection.Iterator<Object> iterator = r.toSeq().iterator();
            while (iterator.hasNext()) {
                list.add(iterator.next());
            }
            list.add(index);
            return RowFactory.create(list.toArray());
        });
        StructType newSchema = df.schema()
                .add(new StructField(name, DataTypes.LongType, true, Metadata.empty()));
        return df.sparkSession().createDataFrame(rdd, newSchema);
    }
}

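Output (with master("local") the data ends up in a single partition here, so index1 happens to be consecutive as well):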

+---+------+----------+
| id|index1|good_index|
+---+------+----------+
|  0|     0|         1|
|  1|     1|         2|
|  2|     2|         3|
|  3|     3|         4|
|  4|     4|         5|
+---+------+----------+