Using Spark to get the names of all columns that have a value over some threshold

Background

We are dumping data from Redshift to S3 and then loading it into a dataframe like this:

df = spark.read.csv(path, schema=schema, sep='|')

We are using PySpark and AWS EMR (version 5.4.0) with Spark 2.1.0.

Problem

I have a Redshift table that is read into PySpark as CSV. The records are in this format:

url,category1,category2,category3,category4
http://example.com,0.6,0.0,0.9,0.3

url is VARCHAR and the category values are FLOAT between 0.0 and 1.0.

I want to create a new DataFrame with one row for each category where the value in the original dataset was above some threshold X. For example, if the threshold was set to 0.5, then I would like my new dataset to look like this:

url,category
http://example.com,category1
http://example.com,category3

I'm new to Spark/PySpark, so I'm not sure how to go about this, or whether it's even possible. Any help would be appreciated!

EDIT:

I wanted to add my solution (based on Pushkr's code). We have a TON of categories to load, so to avoid hardcoding a separate select for each one, I did the following:

from pyspark.sql.functions import col, when

parsed_df = None
for column in column_list:
    # One row per url: the column name if its value exceeds the threshold, else ''
    selected = df.select(df.url, when(df[column] > threshold, column).otherwise('').alias('cat'))
    if parsed_df is None:
        parsed_df = selected
    else:
        parsed_df = parsed_df.union(selected)
if parsed_df is not None:
    # Drop the rows whose category did not pass the threshold
    parsed_df = parsed_df.filter(col('cat') != '')

where column_list is the previously generated list of category column names, and threshold is the value a category must exceed to be selected.

Thanks again!


1 answer


Here is what I tried:

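from pyspark.sql.functions import col, when

# Assumes spark is the SparkSession provided by the pyspark shell.
data = [('http://example.com', 0.6, 0.0, 0.9, 0.3), ('http://example1.com', 0.6, 0.0, 0.9, 0.3)]

df = spark.createDataFrame(data)\
     .toDF('url', 'category1', 'category2', 'category3', 'category4')

# Each select yields the column name when its value exceeds 0.5, otherwise ''.
# union() matches columns by position, so the later branches inherit the
# 'category' alias from the first select; the filter then drops the '' rows.
df\
    .select(df.url, when(df.category1 > 0.5, 'category1').otherwise('').alias('category'))\
    .union(df.select(df.url, when(df.category2 > 0.5, 'category2').otherwise('')))\
    .union(df.select(df.url, when(df.category3 > 0.5, 'category3').otherwise('')))\
    .union(df.select(df.url, when(df.category4 > 0.5, 'category4').otherwise('')))\
    .filter(col('category') != '')\
    .show()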

output:

+-------------------+---------+
|                url| category|
+-------------------+---------+
| http://example.com|category1|
|http://example1.com|category1|
| http://example.com|category3|
|http://example1.com|category3|
+-------------------+---------+
