Here's how to do it with the API functions.
Suppose your DataFrame were the following:
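(The exact rows aren't reproduced here, so as a stand-in you could build a small DataFrame with a single ArrayType column named letters; the sample values below are just an assumption for illustration.)

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Assumed example data: one row with an array<string> column called "letters"
df = spark.createDataFrame([(["a", "b", "c", "d"],)], ["letters"])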
df.show()
df.printSchema()
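With that assumed data, show() and printSchema() would print something like:

+------------+
|     letters|
+------------+
|[a, b, c, d]|
+------------+

root
 |-- letters: array (nullable = true)
 |    |-- element: string (containsNull = true)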
You can use square brackets to access elements in the letters column by index, and wrap that in a call to pyspark.sql.functions.array() to create a new ArrayType column.
import pyspark.sql.functions as f

# Pick elements 0 and 1 of the letters column and combine them into a new array column
df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
Or, if you have too many indices to list out, you can use a list comprehension:
df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
For pyspark versions 2.4+, you can also use pyspark.sql.functions.slice():
df.withColumn("first_two",f.slice("letters",start=1,length=2)).show()
slice may have better performance for large arrays (note that the start index is 1, not 0).
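On the assumed example data, the slice version produces the same first_two column as the array approaches above.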