
Spark now offers predefined functions that can be used in dataframes, and it seems they are highly optimized. My original question was going to be about which is faster, but I did some testing myself and found the Spark functions to be about 10 times faster, at least in one instance. Does anyone know why this is so, and when a UDF would be faster (considering only cases where an identical Spark function exists)?

Here is my testing code (run on Databricks Community Edition):

# UDF vs Spark function
from faker import Factory
from pyspark.sql.functions import lit, concat, udf
fake = Factory.create()
fake.seed(4321)

# Each entry consists of last_name, first_name, ssn, job, and age (at least 1)
from pyspark.sql import Row
def fake_entry():
  name = fake.name().split()
  return (name[1], name[0], fake.ssn(), fake.job(), abs(2016 - fake.date_time().year) + 1)

# Create a helper function to call a function repeatedly
def repeat(times, func, *args, **kwargs):
    for _ in range(times):
        yield func(*args, **kwargs)

data = list(repeat(500000, fake_entry))
print(len(data))
data[0]

dataDF = sqlContext.createDataFrame(data, ('last_name', 'first_name', 'ssn', 'occupation', 'age'))
dataDF.cache()

UDF:

concat_s = udf(lambda s: s+ 's')
udfData = dataDF.select(concat_s(dataDF.first_name).alias('name'))
udfData.count()

Spark Function:

spfData = dataDF.select(concat(dataDF.first_name, lit('s')).alias('name'))
spfData.count()

I ran both multiple times; the UDF usually took about 1.1–1.4 s, while the Spark concat function always took under 0.15 s.

 Answers

when would a udf be faster

If you are asking about Python UDFs, the answer is probably never*. Since SQL functions are relatively simple and are not designed for complex tasks, it is pretty much impossible to compensate for the cost of repeated serialization, deserialization and data movement between the Python interpreter and the JVM.

Does anyone know why this is so

The main reasons are already enumerated above and can be reduced to a simple fact: a Spark DataFrame is natively a JVM structure, and the standard access methods are implemented by simple calls to the Java API. UDFs, on the other hand, are implemented in Python and require moving data back and forth.

While PySpark in general requires data movement between the JVM and Python, the low-level RDD API typically doesn't require expensive serde activity. Spark SQL adds the additional cost of serialization and deserialization, as well as the cost of moving data from and to the unsafe representation on the JVM. The latter applies to all UDFs (Python, Scala and Java), but the former is specific to non-native languages.

Unlike UDFs, Spark SQL functions operate directly on the JVM and are typically well integrated with both Catalyst and Tungsten. This means they can be optimized in the execution plan and most of the time can benefit from codegen and other Tungsten optimizations. Moreover, they can operate on data in its "native" representation.

So, in a sense, the problem here is that a Python UDF has to bring the data to the code, while SQL expressions go the other way around.
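
One quick way to see this with the dataframe from the question is to compare the physical plans. This is just a sketch using the names defined above (dataDF, concat_s), and the exact plan text varies by Spark version:

# Sketch: compare the physical plans of the two variants from the question.
# The Python UDF typically appears as a BatchEvalPython node (rows are shipped
# to a Python worker), while the built-in concat remains a plain,
# codegen-friendly projection over JVM-native data.
dataDF.select(concat_s(dataDF.first_name).alias('name')).explain()
dataDF.select(concat(dataDF.first_name, lit('s')).alias('name')).explain()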


* According to rough estimates, a PySpark window UDF can beat a Scala window function.

answered by JohnnyW

User-defined functions support up to 22 parameters, but the udf helper is only defined for at most 10 arguments. To handle functions with a larger number of parameters you can use org.apache.spark.sql.UDFRegistration.

For example

val dummy = ((
  x0: Int, x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, 
  x8: Int, x9: Int, x10: Int, x11: Int, x12: Int, x13: Int, x14: Int, 
  x15: Int, x16: Int, x17: Int, x18: Int, x19: Int, x20: Int, x21: Int) => 1)

can be registered:

import org.apache.spark.sql.expressions.UserDefinedFunction

val dummyUdf: UserDefinedFunction = spark.udf.register("dummy", dummy)

and used directly:

import org.apache.spark.sql.functions.lit

val df = spark.range(1)
val exprs = (0 to 21).map(_ => lit(1))

df.select(dummyUdf(exprs: _*))

or by name via callUDF:

import org.apache.spark.sql.functions.callUDF

df.select(
  callUDF("dummy", exprs:  _*).alias("dummy")
)

or via a SQL expression:

df.selectExpr(s"""dummy(${Seq.fill(22)(1).mkString(",")})""")

You can also create a UserDefinedFunction object directly:

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.IntegerType

Seq(1).toDF.select(UserDefinedFunction(dummy, IntegerType, None)(exprs: _*))

In practice, having a function with 22 arguments is not very useful, and unless you want to use Scala reflection to generate these, it is a maintenance nightmare.

I would consider either using collections (array, map) or a struct as input, or dividing this into multiple modules. For example:

import org.apache.spark.sql.functions.{array, lit, udf}

val aLongArray = array((0 to 256).map(_ => lit(1)): _*)

val udfWithArray = udf((xs: Seq[Int]) => 1)

Seq(1).toDF.select(udfWithArray(aLongArray).alias("dummy"))
answered by octern

What you see is a difference between the implementation of Limit (a transformation-like operation) and CollectLimit (an action-like operation). However, the difference in timings is highly misleading and not something you can expect in the general case.

First, let's create an MCVE:

spark.conf.set("spark.sql.files.maxPartitionBytes", 500)

val ds = spark.read
  .text("README.md")
  .as[String]
  .map{ x => {
    Thread.sleep(1000)
    x
   }}

val dsLimit4 = ds.limit(4)

and make sure we start with a clean slate:

spark.sparkContext.statusTracker.getJobIdsForGroup(null).isEmpty
Boolean = true

invoke count:

dsLimit4.count()

and take a look at the execution plan (from Spark UI):

== Parsed Logical Plan ==
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
   +- LocalLimit 4
      +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
         +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
            +- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
               +- Relation[value#0] text

== Analyzed Logical Plan ==
count: bigint
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
   +- LocalLimit 4
      +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
         +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
            +- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
               +- Relation[value#0] text

== Optimized Logical Plan ==
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
   +- LocalLimit 4
      +- Project
         +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
            +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
               +- DeserializeToObject value#0.toString, obj#5: java.lang.String
                  +- Relation[value#0] text

== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count#12L])
+- *(2) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#15L])
   +- *(2) GlobalLimit 4
      +- Exchange SinglePartition
         +- *(1) LocalLimit 4
            +- *(1) Project
               +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
                  +- *(1) MapElements <function1>, obj#6: java.lang.String
                     +- *(1) DeserializeToObject value#0.toString, obj#5: java.lang.String
                        +- *(1) FileScan text [value#0] Batched: false, Format: Text, Location: InMemoryFileIndex[file:/path/to/README.md], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<value:string>

The core component is

+- *(2) GlobalLimit 4
   +- Exchange SinglePartition
      +- *(1) LocalLimit 4

which indicates that we can expect a wide operation with multiple stages. We can see a single job

spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(0)

with two stages

spark.sparkContext.statusTracker.getJobInfo(0).get.stageIds
Array[Int] = Array(0, 1)

with eight tasks

spark.sparkContext.statusTracker.getStageInfo(0).get.numTasks
Int = 8

and one task

spark.sparkContext.statusTracker.getStageInfo(1).get.numTasks
Int = 1

respectively.

Now let's compare it to

dsLimit4.take(300).size

which generates the following:

== Parsed Logical Plan ==
GlobalLimit 300
+- LocalLimit 300
   +- GlobalLimit 4
      +- LocalLimit 4
         +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
            +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
               +- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
                  +- Relation[value#0] text

== Analyzed Logical Plan ==
value: string
GlobalLimit 300
+- LocalLimit 300
   +- GlobalLimit 4
      +- LocalLimit 4
         +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
            +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
               +- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
                  +- Relation[value#0] text

== Optimized Logical Plan ==
GlobalLimit 4
+- LocalLimit 4
   +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
      +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
         +- DeserializeToObject value#0.toString, obj#5: java.lang.String
            +- Relation[value#0] text

== Physical Plan ==
CollectLimit 4
+- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
   +- *(1) MapElements <function1>, obj#6: java.lang.String
      +- *(1) DeserializeToObject value#0.toString, obj#5: java.lang.String
         +- *(1) FileScan text [value#0] Batched: false, Format: Text, Location: InMemoryFileIndex[file:/path/to/README.md], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<value:string>

While both global and local limits still occur, there is no exchange in the middle. Therefore we can expect a single-stage operation. Please note that the planner narrowed the limit down to the more restrictive value.

As expected we see a single new job:

spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(1, 0)

which generated only one stage:

spark.sparkContext.statusTracker.getJobInfo(1).get.stageIds
Array[Int] = Array(2)

with only one task

spark.sparkContext.statusTracker.getStageInfo(2).get.numTasks
Int = 1

What does this mean for us?

  • In the count case Spark used a wide transformation: it actually applies LocalLimit on each partition and shuffles the partial results to perform GlobalLimit.
  • In the take case Spark used a narrow transformation and evaluated LocalLimit only on the first partition.

Obviously the latter approach won't work when the number of values in the first partition is lower than the requested limit.

val dsLimit105 = ds.limit(105) // There are 105 lines

In such a case count will use exactly the same logic as before (I encourage you to confirm that empirically), but take will follow a rather different path. So far we have triggered only two jobs:

spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(1, 0)

Now if we execute

dsLimit105.take(300).size

you'll see that it requires 3 more jobs:

spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(4, 3, 2, 1, 0)

So what's going on here? As noted before, evaluating a single partition is not enough to satisfy the limit in the general case. In that case Spark iteratively evaluates LocalLimit on partitions until GlobalLimit is satisfied, increasing the number of partitions taken in each iteration.
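
As an aside (not part of the original answer), the rate at which Spark widens each successive attempt is controlled by the spark.sql.limit.scaleUpFactor configuration, which defaults to 4 in recent Spark versions, so it can be tuned if this iterative fallback becomes a problem:

# Hypothetical tuning knob; the PySpark and Scala calls look identical.
spark.conf.set("spark.sql.limit.scaleUpFactor", 8)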

Such a strategy can have significant performance implications. Starting Spark jobs alone is not cheap, and in cases where the upstream object is the result of a wide transformation things can get quite ugly (in the best-case scenario you can read shuffle files, but if these are lost for some reason, Spark might be forced to re-execute all the dependencies).

To summarize:

  • take is an action, and can short-circuit in specific cases where the upstream process is narrow and LocalLimit can satisfy GlobalLimit using only the first few partitions.
  • limit is a transformation, and always evaluates all LocalLimits, as there is no iterative escape hatch.

While one can behave better than the other in specific cases, they are not interchangeable and neither guarantees better performance in general.
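
The same distinction can be observed from PySpark. Here is a minimal sketch (assuming a SparkSession named spark and a local text file; the file name is a placeholder and job counts depend on how the input is partitioned):

# Minimal PySpark sketch of the same experiment.
ds = spark.read.text("README.md")
ds_limit4 = ds.limit(4)

ds_limit4.count()      # wide plan: LocalLimit -> Exchange -> GlobalLimit
ds_limit4.take(300)    # CollectLimit: partitions evaluated incrementally

# The same job bookkeeping is available from Python:
tracker = spark.sparkContext.statusTracker()
print(tracker.getJobIdsForGroup())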

answered by TecHunter

There's a section on the Databricks spark-xml GitHub page which talks about parsing nested XML, and it provides a solution using the Scala API, as well as a couple of PySpark helper functions to work around the issue that there is no separate Python package for spark-xml. So using these, here's one way you could solve the problem:

# 1. Copy helper functions from https://github.com/databricks/spark-xml#pyspark-notes

from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.types import _parse_datatype_json_string
import pyspark.sql.functions as F


def ext_from_xml(xml_column, schema, options={}):
    java_column = _to_java_column(xml_column.cast('string'))
    java_schema = spark._jsparkSession.parseDataType(schema.json())
    scala_map = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap(options)
    jc = spark._jvm.com.databricks.spark.xml.functions.from_xml(
        java_column, java_schema, scala_map)
    return Column(jc)

def ext_schema_of_xml_df(df, options={}):
    assert len(df.columns) == 1

    scala_options = spark._jvm.PythonUtils.toScalaMap(options)
    java_xml_module = getattr(getattr(
        spark._jvm.com.databricks.spark.xml, "package$"), "MODULE$")
    java_schema = java_xml_module.schema_of_xml_df(df._jdf, scala_options)
    return _parse_datatype_json_string(java_schema.json())

# 2. Set up example dataframe

xml = '<?xml version="1.0" encoding="utf-8"?> <visitors> <visitor id="9615" age="68" sex="F" /> <visitor id="1882" age="34" sex="M" /> <visitor id="5987" age="23" sex="M" /> </visitors>'

df = spark.createDataFrame([('1',xml)],['id','visitors'])
df.show()

# +---+--------------------+
# | id|            visitors|
# +---+--------------------+
# |  1|<?xml version="1....|
# +---+--------------------+

# 3. Get xml schema and parse xml column

payloadSchema = ext_schema_of_xml_df(df.select("visitors"))
parsed = df.withColumn("parsed", ext_from_xml(F.col("visitors"), payloadSchema))
parsed.show()

# +---+--------------------+--------------------+
# | id|            visitors|              parsed|
# +---+--------------------+--------------------+
# |  1|<?xml version="1....|[[[, 68, 9615, F]...|
# +---+--------------------+--------------------+

# 4. Extract 'visitor' field from StructType
df2 = parsed.select(*parsed.columns[:-1],F.explode(F.col('parsed').getItem('visitor')))
df2.show()

# +---+--------------------+---------------+
# | id|            visitors|            col|
# +---+--------------------+---------------+
# |  1|<?xml version="1....|[, 68, 9615, F]|
# |  1|<?xml version="1....|[, 34, 1882, M]|
# |  1|<?xml version="1....|[, 23, 5987, M]|
# +---+--------------------+---------------+

# 5. Get field names, which will become new columns
# (there's probably a much better way of doing this :D)
new_col_names = [s.split(':')[0] for s in payloadSchema['visitor'].simpleString().split('<')[-1].strip('>>').split(',')]

new_col_names

# ['_VALUE', '_age', '_id', '_sex']
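
# (Not part of the original answer: assuming the schema really is an
#  array<struct<...>> under 'visitor', the same names can be read straight
#  off the schema object, which avoids the string parsing above.)
# new_col_names = payloadSchema['visitor'].dataType.elementType.fieldNames()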

# 6. Create new columns

for c in new_col_names:
    df2 = df2.withColumn(c, F.col('col').getItem(c))
    
df2 = df2.drop('col','_VALUE')

df2.show()

# +---+--------------------+----+----+----+
# | id|            visitors|_age| _id|_sex|
# +---+--------------------+----+----+----+
# |  1|<?xml version="1....|  68|9615|   F|
# |  1|<?xml version="1....|  34|1882|   M|
# |  1|<?xml version="1....|  23|5987|   M|
# +---+--------------------+----+----+----+

One thing to look out for is the new column names duplicating existing column names. In this case the new column names are all preceded by underscores, so we don't have any duplication, but it's probably good to check beforehand that the nested XML tags don't conflict with existing column names.
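
A quick sanity check along those lines could look like this (a hypothetical addition, not part of the original answer):

# Hypothetical pre-check: make sure the extracted field names don't collide
# with columns that already exist on the dataframe.
conflicts = set(new_col_names) & set(df.columns)
if conflicts:
    raise ValueError("nested XML tags clash with existing columns: %s" % conflicts)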

answered by SkyNet

You should list all the columns explicitly. For example:

from pyspark.sql import functions as F
from pyspark.sql.types import BooleanType

# create sample df
df = sc.parallelize([
    (1, 'b'),
    (1, 'c'),
]).toDF(["id", "category"])

#simple filter function
@F.udf(returnType=BooleanType())
def my_filter(col1, col2):
    return (col1>0) & (col2=="b")

df.filter(my_filter('id', 'category')).show()

Results:

+---+--------+
| id|category|
+---+--------+
|  1|       b|
+---+--------+

If you have many columns and you are sure of the order of the columns:

cols = df.columns
df.filter(my_filter(*cols)).show()

Yields the same output.
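
If you would rather not rely on the order of the columns, one possible alternative (a sketch, not part of the original answer) is to pass a single struct and look fields up by name inside the UDF:

# Sketch: pass the whole row as one struct column, so the UDF can access
# fields by name instead of relying on positional arguments.
@F.udf(returnType=BooleanType())
def my_filter_struct(row):
    return (row['id'] > 0) and (row['category'] == "b")

df.filter(my_filter_struct(F.struct(*df.columns))).show()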

answered by Rasmus Puls