I have a PySpark dataframe (say df) which has two columns (Name and Score). Following is an example of the dataframe:
+------+-----+
|  Name|Score|
+------+-----+
| name1|11.23|
| name2|14.57|
| name3| 2.21|
| name4| 8.76|
| name5|18.71|
+------+-----+
I have a numpy array (say bin_array) whose values are close to the numerical values in the Score column of the PySpark dataframe:
bin_array = np.array([0, 5, 10, 15, 20])
I want to compare the value in each row of the Score column with the values in bin_array and store the closest value (taken from bin_array) in a separate column of the PySpark dataframe. Below is how I would like the new dataframe (say df_new) to look:
+------+-----+-----------+
|  Name|Score|Closest_bin|
+------+-----+-----------+
| name1|11.23|       10.0|
| name2|14.57|       15.0|
| name3| 2.21|        0.0|
| name4| 8.76|       10.0|
| name5|18.71|       20.0|
+------+-----+-----------+
I have the following function, which returns the closest value from bin_array. It works fine when I test it with individual numbers:
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return float(array[idx])
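For example, testing it with the first Score value from the table above returns the expected bin:

print(find_nearest(bin_array, 11.23))   # prints 10.0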
In my actual work, I will have millions of rows in the dataframe. What is the most efficient way to create df_new?
Following are the steps I tried to create a user-defined function (UDF) and the new dataframe (df_new):

closest_bin_udf = F.udf(lambda x: find_nearest(array, x))
df_new = df.withColumn('Closest_bin', closest_bin_udf(df.Score))
But I got errors when I tried df_new.show(). A portion of the error is shown below:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-11-685c9b7e25d9> in <module>()
----> 1 df_new.show()
/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
376 """
377 if isinstance(truncate, bool) and truncate:
--> 378 print(self._jdf.showString(n, 20, vertical))
379 else:
380 print(self._jdf.showString(n, int(truncate), vertical))
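For reference, below is a minimal sketch of the UDF step that I would expect to run, assuming the lambda is meant to close over bin_array and that the return type should be declared explicitly (F.udf defaults to StringType when no return type is given):

from pyspark.sql.types import DoubleType

# Sketch only (assumption): reference bin_array in the lambda and
# declare DoubleType so Closest_bin comes out numeric, not a string
closest_bin_udf = F.udf(lambda x: find_nearest(bin_array, x), DoubleType())
df_new = df.withColumn('Closest_bin', closest_bin_udf(df.Score))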
You can use the following steps to create the dataframe shown at the top:
from pyspark.sql import Row
import pyspark.sql.functions as F
import numpy as np

Stats = Row("Name", "Score")
stat1 = Stats('name1', 11.23)
stat2 = Stats('name2', 14.57)
stat3 = Stats('name3', 2.21)
stat4 = Stats('name4', 8.76)
stat5 = Stats('name5', 18.71)
stat_lst = [stat1, stat2, stat3, stat4, stat5]
df = spark.createDataFrame(stat_lst)
df.show()
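As an aside, since bin_array is small and fixed, one UDF-free idea would be to pick the closest bin from the midpoints between adjacent bins, using only built-in column functions (sketch only, not benchmarked; a Score exactly on a midpoint goes to the upper bin here, whereas argmin picks the lower one):

# Sketch (assumes bin_array is sorted): choose the closest bin with a
# chain of when() conditions on the midpoints between adjacent bins
mids = (bin_array[:-1] + bin_array[1:]) / 2.0   # [2.5, 7.5, 12.5, 17.5]
expr = F.lit(float(bin_array[-1]))
for mid, b in zip(mids[::-1], bin_array[-2::-1]):
    expr = F.when(F.col('Score') < float(mid), float(b)).otherwise(expr)
df_new = df.withColumn('Closest_bin', expr)
df_new.show()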