I am using Spark 2.3.0 and trying out the pandas_udf user-defined functions in my PySpark code. According to https://github.com/apache/spark/pull/20114, ArrayType is currently supported. My user-defined function is:
import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, LongType

def transform(c):
    # If no element of the incoming series is itself a sequence, use the values as-is
    if not any(isinstance(x, (list, tuple, np.ndarray)) for x in c.values):
        nvalues = c.values
    else:
        nvalues = np.array(c.values.tolist())
    tvalues = some_external_function(nvalues)
    # Wrap the transformed values back into a pandas Series of arrays
    if not any(isinstance(y, (list, tuple, np.ndarray)) for y in tvalues):
        p = pd.Series(np.array(tvalues))
    else:
        p = pd.Series(list(tvalues))
    return p

transform = pandas_udf(transform, ArrayType(LongType()))
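I then apply the UDF to the array column roughly like this (the DataFrame and column name below are placeholders, not the real ones from my job):

from pyspark.sql.functions import col

# Hypothetical application; 'my_array_col' stands in for the real array<long> column
df = df.withColumn('transformed', transform(col('my_array_col')))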
When I apply this function to a specific array column of a large Spark DataFrame, I notice that the first element of the pandas Series c has double the size of the others, and the last one has size 0:
0 [73, 10, 223, 46, 14, 73, 14, 5, 14, 21, 10, 2...
1 [223, 46, 14, 73, 14, 5, 14, 21, 30, 16]
2 [46, 14, 73, 14, 5, 14, 21, 30, 16, 15]
...
4695 []
Name: _70, Length: 4696, dtype: object
The first array has 20 elements, the second has 10 (which is the correct size), and the last one has 0. And then, of course, c.values fails with ValueError: setting an array element with a sequence, since the arrays have different sizes.
When I try the same function on a column with arrays of strings, all the sizes are correct, and so are the rest of the function's steps.
Any idea what might be the issue? Possible bug?
UPDATED with example:
I am attaching a simple example that just prints the values inside the pandas_udf function.
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import SparkSession

if __name__ == "__main__":

    spark = SparkSession\
        .builder\
        .appName("testing pandas_udf")\
        .getOrCreate()

    arr = []
    for i in range(100000):
        arr.append([2,2,2,2,2])

    df = spark.createDataFrame(arr, ArrayType(LongType()))

    def transform(c):
        print(c)
        print(c.values)
        return c

    transform = pandas_udf(transform, ArrayType(LongType()))

    df = df.withColumn('new_value', transform(col('value')))
    df.show()

    spark.stop()
If you check the executor's logs, you will see output like:
0 [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
1 [2, 2, 2, 2, 2]
2 [2, 2, 2, 2, 2]
3 [2, 2, 2, 2, 2]
4 [2, 2, 2, 2, 2]
5 [2, 2, 2, 2, 2]
...
9996 [2, 2, 2, 2, 2]
9997 [2, 2, 2, 2, 2]
9998 []
9999 []
Name: _0, Length: 10000, dtype: object
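Besides digging through the executor logs, a quick driver-side check is to group by the length of the arrays the UDF returns (just a sketch building on the df from the example above, and assuming the corrupted batches also make it into the result):

from pyspark.sql.functions import col, size

# All rows should come back with arrays of length 5; any other length points at the corruption
df.groupBy(size(col('new_value')).alias('n')).count().show()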
SOLVED:
If you face the same issue, upgrade to Spark 2.3.1 and pyarrow 0.9.0.post1.
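To double-check what is actually installed in the Python environment (note that the pyarrow upgrade has to reach the executors too, since that is where the pandas_udf runs), a quick sanity check is:

import pyspark
import pyarrow

# Print the installed versions; they should reflect the upgraded Spark and pyarrow
print(pyspark.__version__)
print(pyarrow.__version__)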