Spark data frames from CSV files: handling headers & column types

Christos - Iraklis Tsatsoulis | Big Data, Spark

If you come from the R (or Python/pandas) universe, like me, you probably take it for granted that working with CSV files is one of the most natural and straightforward things in a data analysis context. Indeed, if you have your data in a CSV file, practically the only thing you have to do from R is to fire a read.csv command (or, even better, an fread one from the excellent data.table R package), and your data import is practically finished. Your data do not fit into main memory? No problem, that’s exactly what R packages like ff and LaF (standing for Large Files) were made for. And in almost all of these cases you do not have to worry about the schema, i.e. the data column types (usually they are transparently inferred from the file). Last but not least, having a header (usually with the field names) in the file is not only not a problem; on the contrary, it provides at least a good-enough starting point for your column names…

Very few of these conveniences survive if you step out of these R and Python/pandas worlds: CSV file headers in Hadoop are usually a nuisance, which has to be taken care of so as not to mess up the actual data; other structured data file formats prevail, like JSON and Parquet; and as for automatic schema detection from CSV files, we will probably have to wait a little longer…

Over the last few days I have been delving into the recently introduced data frames for Apache Spark (available since version 1.3, released last March), and for a good reason: coming from an R background, I feel ultra-comfortable working with a data structure that is practically native in R, and I am excited to see this data structure augmented and extended for the big data analytics era.

So in this post I am going to share my initial journey with Spark data frames, a little further away from the trivial 2-rows-and-2-columns examples found in the documentation. I will use the Python API (PySpark), which I hope will be of some additional value, since most of the (still sparse, anyway) existing material on the Web comes in Scala; and I will use a CSV file with a header as a starting point, which you can download here.

After running pyspark from the command line, we get the welcome screen, and we proceed to import the necessary modules and initialize our SQLContext:

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 1.3.0
      /_/

Using Python version 2.7.6 (default, Mar 22 2014 22:59:56)
SparkContext available as sc, HiveContext available as sqlCtx.
>>> from pyspark.sql import SQLContext
>>> from pyspark.sql.types import *
>>> sqlContext = SQLContext(sc)

(I will be showing the commands as entered at the prompt, but the whole code, with text and comments, is included in a downloadable IPython notebook.)

Let’s first import our CSV file, and have a first look at it (for convenience, we omit the various Spark diagnostic messages displayed on screen). Be sure to change the first line to reflect the path where you have unzipped and saved the file nyctaxisub.csv:

>>> taxiFile = sc.textFile("file:///home/ctsats/datasets/BDU_Spark/nyctaxisub.csv")
>>> taxiFile.count()
250000
>>> taxiFile.take(5)
[u'"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"',
 u'"29b3f4a30dea6688d4c289c9672cb996","1-ddfdec8050c7ef4dc694eeeda6c4625e","2013-01-11 22:03:00",+4.07033460000000E+001,-7.40144200000000E+001,"A93D1F7F8998FFB75EEF477EB6077516","68BC16A99E915E44ADA7E639B4DD5F59",2,"2013-01-11 21:48:00",+4.06760670000000E+001,-7.39810790000000E+001,1,,+4.08000000000000E+000,900,"VTS"',
 u'"2a80cfaa425dcec0861e02ae44354500","1-b72234b58a7b0018a1ec5d2ea0797e32","2013-01-11 04:28:00",+4.08190960000000E+001,-7.39467470000000E+001,"64CE1B03FDE343BB8DFB512123A525A4","60150AA39B2F654ED6F0C3AF8174A48A",1,"2013-01-11 04:07:00",+4.07280540000000E+001,-7.40020370000000E+001,1,,+8.53000000000000E+000,1260,"VTS"',
 u'"29b3f4a30dea6688d4c289c96758d87e","1-387ec30eac5abda89d2abefdf947b2c1","2013-01-11 22:02:00",+4.07277180000000E+001,-7.39942860000000E+001,"2D73B0C44F1699C67AB8AE322433BDB7","6F907BC9A85B7034C8418A24A0A75489",5,"2013-01-11 21:46:00",+4.07577480000000E+001,-7.39649810000000E+001,1,,+3.01000000000000E+000,960,"VTS"',
 u'"2a80cfaa425dcec0861e02ae446226e4","1-aa8b16d6ae44ad906a46cc6581ffea50","2013-01-11 10:03:00",+4.07643050000000E+001,-7.39544600000000E+001,"E90018250F0A009433F03BD1E4A4CE53","1AFFD48CC07161DA651625B562FE4D06",5,"2013-01-11 09:44:00",+4.07308080000000E+001,-7.39928280000000E+001,1,,+3.64000000000000E+000,1140,"VTS"']

Our data consist of NYC taxi ride information, including pickup/dropoff datetimes and locations. We have 250,000 lines, including a header with the field names; we have missing data (notice the double-comma sequences toward the end of the displayed records, corresponding to the store_and_fwd_flag field); each record is a single long (unicode) string (u'...'); and all string-type data (including datetimes and field names) are themselves double-quoted. In brief, and apart from the small dataset size, this is arguably a rather realistic CSV data source.

As already mentioned, at this stage our data is nothing more than a bunch of long string records. As a first step towards building a dataframe, we isolate the header, in order to eventually use it to get the field names:

>>> header = taxiFile.first()
>>> header
u'"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"'

We want to get rid of the double quotes around the field names, and then use the header to build the fields for our schema, similarly to the relevant example in the Spark SQL documentation:

>>> schemaString = header.replace('"','')  # get rid of the double-quotes
>>> schemaString
u'_id,_rev,dropoff_datetime,dropoff_latitude,dropoff_longitude,hack_license,medallion,passenger_count,pickup_datetime,pickup_latitude,pickup_longitude,rate_code,store_and_fwd_flag,trip_distance,trip_time_in_secs,vendor_id'
>>> fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split(',')]
>>> fields
[StructField(_id,StringType,true),
 StructField(_rev,StringType,true),
 StructField(dropoff_datetime,StringType,true),
 StructField(dropoff_latitude,StringType,true),
 StructField(dropoff_longitude,StringType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,StringType,true),
 StructField(pickup_datetime,StringType,true),
 StructField(pickup_latitude,StringType,true),
 StructField(pickup_longitude,StringType,true),
 StructField(rate_code,StringType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,StringType,true),
 StructField(trip_time_in_secs,StringType,true),
 StructField(vendor_id,StringType,true)]
>>> len(fields)  # how many elements in the header?
16
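As an alternative to .replace('"',''), the header could be split with Python's standard csv module, which removes the quoting itself and would also cope with a quoted field name that contains a comma. A minimal plain-Python sketch, assuming Python 3 (Python 2's csv module has unicode caveats); header_names is a hypothetical helper name:

```python
import csv

def header_names(header_line):
    """Split a CSV header line into bare field names.

    csv.reader strips the surrounding double quotes itself, and would
    also handle a quoted name that contains a comma."""
    # csv.reader takes an iterable of lines; we pass a one-element list
    return next(csv.reader([header_line]))

print(header_names('"_id","_rev","dropoff_datetime"'))
# ['_id', '_rev', 'dropoff_datetime']
```

The resulting list could feed the same StructField list comprehension shown above.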

At this stage, our fields are all StringType; this is deliberate, since it provides a quick way to initialize them. Now we can manually modify the fields which should not be of type String. By inspecting the data shown so far, it is not difficult to infer which columns should be of type Float, Integer, and Timestamp. We leave the details to the reader, and proceed to modify the respective fields so that they reflect the correct data types:

>>> fields[2].dataType = TimestampType()
>>> fields[3].dataType = FloatType()
>>> fields[4].dataType = FloatType()
>>> fields[7].dataType = IntegerType()
>>> fields[8].dataType = TimestampType()
>>> fields[9].dataType = FloatType()
>>> fields[10].dataType = FloatType()
>>> fields[11].dataType = IntegerType()
>>> fields[13].dataType = FloatType()
>>> fields[14].dataType = IntegerType()
>>> fields
[StructField(_id,StringType,true),
 StructField(_rev,StringType,true),
 StructField(dropoff_datetime,TimestampType,true),
 StructField(dropoff_latitude,FloatType,true),
 StructField(dropoff_longitude,FloatType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,IntegerType,true),
 StructField(pickup_datetime,TimestampType,true),
 StructField(pickup_latitude,FloatType,true),
 StructField(pickup_longitude,FloatType,true),
 StructField(rate_code,IntegerType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,FloatType,true),
 StructField(trip_time_in_secs,IntegerType,true),
 StructField(vendor_id,StringType,true)]
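The inspection step can also be partly automated. Below is a rough heuristic sketch (plain Python 3, not part of the workflow above; guess_type and guess_types are hypothetical helper names) that tries each conversion on the values of one sample record, from strictest to loosest. It is only a guess from a single row: an empty field gives no evidence, and a hex identifier that happens to look like a number (e.g. only digits and a single E) would be misclassified, so the result should still be checked by eye.

```python
import csv
from datetime import datetime

def guess_type(value):
    """Guess a Spark-style type name for one raw CSV value (heuristic)."""
    if value == '':
        return 'unknown'              # empty field: no evidence either way
    for name, convert in (('integer', int), ('float', float)):
        try:
            convert(value)
            return name
        except ValueError:
            pass
    try:
        # timestamp layout seen in the sample rows (an assumption)
        datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
        return 'timestamp'
    except ValueError:
        return 'string'

def guess_types(line):
    # csv.reader removes the double quotes around string fields
    return [guess_type(v) for v in next(csv.reader([line]))]

print(guess_types('"abc",2,+4.08E+000,"2013-01-11 22:03:00",'))
# ['string', 'integer', 'float', 'timestamp', 'unknown']
```

Applied to the first data record shown earlier, this heuristic picks the same Float, Integer and Timestamp positions as the manual assignments above, with store_and_fwd_flag left as 'unknown' because it is empty in that row.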

Want to also get rid of these annoying leading underscores in the first two field names? We can change them in a similarly easy and straightforward way:

>>> fields[0].name = 'id'
>>> fields[1].name = 'rev'
>>> fields
[StructField(id,StringType,true),
 StructField(rev,StringType,true),
 StructField(dropoff_datetime,TimestampType,true),
 StructField(dropoff_latitude,FloatType,true),
 StructField(dropoff_longitude,FloatType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,IntegerType,true),
 StructField(pickup_datetime,TimestampType,true),
 StructField(pickup_latitude,FloatType,true),
 StructField(pickup_longitude,FloatType,true),
 StructField(rate_code,IntegerType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,FloatType,true),
 StructField(trip_time_in_secs,IntegerType,true),
 StructField(vendor_id,StringType,true)]
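Renaming field by field is fine for two columns; for a wider file, a loop can strip leading underscores from every name at once. A minimal sketch, assuming fields is a list of objects with a writable .name attribute, as the StructField objects above are; the helper name is hypothetical, and SimpleNamespace objects stand in for StructFields so that the snippet runs without Spark:

```python
from types import SimpleNamespace

def strip_leading_underscores(fields):
    """Rename every field in place, dropping any leading underscores."""
    for f in fields:
        f.name = f.name.lstrip('_')   # '_id' -> 'id', 'vendor_id' unchanged
    return fields

# Stand-ins for StructField objects (illustration only)
demo = [SimpleNamespace(name='_id'), SimpleNamespace(name='_rev'),
        SimpleNamespace(name='vendor_id')]
print([f.name for f in strip_leading_underscores(demo)])
# ['id', 'rev', 'vendor_id']
```

Note that lstrip removes all leading underscores, so a name like '__x' would become 'x'.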

Now that we are satisfied with the data types, we can construct our schema, which we will use below for building the data frame:

>>> schema = StructType(fields)

Recall from our introduction above that the presence of the header along with the data in a single file is something that needs to be taken care of. It is rather easy to isolate the header from the actual data, and then drop it using Spark’s .subtract() method for RDDs:

>>> taxiHeader = taxiFile.filter(lambda l: "_id" in l)
>>> taxiHeader.collect()
[u'"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"']
>>> taxiNoHeader = taxiFile.subtract(taxiHeader)
>>> taxiNoHeader.count()
249999

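Given that we had 250,000 rows in our file, we end up as expected with just one row less, i.e. 249,999. And just to be clear, we could not have subtracted the header variable already calculated: header is just a local Python string, not an RDD, and .subtract() only accepts another RDD (a local string can, however, be used inside a filter, e.g. taxiFile.filter(lambda l: l != header)). Note also that .subtract() involves a shuffle, so the resulting RDD does not preserve the original row order.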
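Another common idiom avoids the shuffle altogether: drop the first line of the first partition only. A minimal sketch, assuming the RDD was read with sc.textFile from a single CSV file, so that the header is the first line of partition 0 (with a directory of several CSV files, every other file's header would remain). drop_header is a hypothetical name; RDD.mapPartitionsWithIndex calls the function with the partition index and an iterator over that partition's lines. The check at the bottom simulates two partitions in plain Python:

```python
from itertools import islice

def drop_header(partition_index, lines):
    """Skip the first line of partition 0; pass other partitions through."""
    if partition_index == 0:
        return islice(lines, 1, None)
    return lines

# In PySpark this would be used as (not run here):
#   taxiNoHeader = taxiFile.mapPartitionsWithIndex(drop_header)

# Plain-Python simulation of two partitions
parts = [iter(['"_id","_rev"', 'row1']), iter(['row2'])]
print([list(drop_header(i, p)) for i, p in enumerate(parts)])
# [['row1'], ['row2']]
```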

We are now almost ready for the final step before actually building our data frame: the idea is that, after a first .map() operation that splits the row contents using the appropriate field separator (comma in our case), we chain a second one, where we include the fields either as-is, for StringTypes, or with the appropriate conversion, for FloatTypes, IntegerTypes, and TimestampTypes (all available in the pyspark.sql.types module). But before doing that, we have to import the necessary Python modules in order to correctly deal with datetimes:

>>> from datetime import *
>>> from dateutil.parser import parse
>>> parse("2013-02-09 18:16:10")  # test it:
datetime.datetime(2013, 2, 9, 18, 16, 10)
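dateutil's parse() guesses the format of each value it sees. Since all the timestamps shown so far share one layout, the standard library's datetime.strptime with an explicit format is a stricter alternative: it raises ValueError on anything unexpected instead of guessing, and needs no extra package. A minimal sketch, assuming that every timestamp in the file follows this layout (parse_ts and TS_FORMAT are hypothetical names):

```python
from datetime import datetime

TS_FORMAT = '%Y-%m-%d %H:%M:%S'   # layout seen in the sample rows (assumed)

def parse_ts(raw):
    """Parse a possibly double-quoted timestamp string with a fixed format."""
    return datetime.strptime(raw.strip('"'), TS_FORMAT)

print(parse_ts('"2013-02-09 18:16:10"'))   # 2013-02-09 18:16:10
```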

Our situation, with all these double-quoted strings nested inside longer single-quoted strings, requires some tweaking with parse(); the Python .strip() method for strings comes in handy for removing the “internal” double quotes from our datetime strings. Here is the final one-liner (note that splitting on "," assumes that no quoted field contains a comma itself):

>>> taxi_temp = taxiNoHeader.map(lambda k: k.split(",")).map(lambda p: (p[0], p[1], parse(p[2].strip('"')), float(p[3]), float(p[4]) , p[5], p[6] , int(p[7]), parse(p[8].strip('"')), float(p[9]), float(p[10]), int(p[11]), p[12], float(p[13]), int(p[14]), p[15] ))
>>> taxi_temp.top(2)  # have a look:
[(u'"fff43e5eb5662eecf42a3f9b5ff42214"',
  u'"1-2e9ea2f49a29663d699d1940f42fab66"',
  datetime.datetime(2013, 11, 26, 13, 15),
  40.764915,
  -73.982536,
  u'"564F38A1BC4B1AA7EC528E6C2C81EAAC"',
  u'"3E29713986A6762D985C4FC53B177F61"',
  1,
  datetime.datetime(2013, 11, 26, 13, 2),
  40.786667,
  -73.972023,
  1,
  u'',
  1.87,
  780,
  u'"VTS"'),
 (u'"fff43e5eb5662eecf42a3f9b5ff1fc5b"',
  u'"1-18b010dab3a3f83ebf4b9f31e88c615d"',
  datetime.datetime(2013, 11, 26, 3, 59),
  40.686081,
  -73.952072,
  u'"5E3208C5FA0E44EA08223489E3853EAD"',
  u'"DC67FC4851D7642EDCA34A8A3C44F116"',
  1,
  datetime.datetime(2013, 11, 26, 3, 42),
  40.740715,
  -74.004562,
  1,
  u'',
  5.84,
  1020,
  u'"VTS"')]

From simple inspection, it seems that all four of our data types are now correctly identified. We are now ready to build our data frame, using the taxi_temp RDD computed above and the schema variable already defined:

>>> taxi_df = sqlContext.createDataFrame(taxi_temp, schema)
>>> taxi_df.head(10)  # look at the first 10 rows:
[Row(id=u'"e6b3fa7bee24a30c25ce87e44e714457"', rev=u'"1-9313152f4894bb47678d8ce98e9ec733"', dropoff_datetime=datetime.datetime(2013, 2, 9, 18, 16), dropoff_latitude=40.73524856567383, dropoff_longitude=-73.99406433105469, hack_license=u'"88F8DD623E5090083988CD32C84973E3"', medallion=u'"6B96DDFB5A50B96E72F5808ABE778B17"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 9, 17, 59), pickup_latitude=40.775123596191406, pickup_longitude=-73.96345520019531, rate_code=1, store_and_fwd_flag=u'', trip_distance=3.4600000381469727, trip_time_in_secs=1020, vendor_id=u'"VTS"'),
 Row(id=u'"cbee283a4613f85af67f79c6d7721234"', rev=u'"1-c1bd2aecbf3936b30c486aa3deade97b"', dropoff_datetime=datetime.datetime(2013, 1, 11, 17, 2), dropoff_latitude=40.826969146728516, dropoff_longitude=-73.94998931884766, hack_license=u'"5514E59A5CEA0379EA6F7F12ABE87489"', medallion=u'"3541D0677EEEA07B67E645E12F04F517"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 1, 11, 16, 29), pickup_latitude=40.77362823486328, pickup_longitude=-73.87080383300781, rate_code=1, store_and_fwd_flag=u'', trip_distance=8.229999542236328, trip_time_in_secs=1980, vendor_id=u'"VTS"'),
 Row(id=u'"e4fb64b76eb99d4ac222713eb36f1afb"', rev=u'"1-233ff643b7f105b7a76ec05cf4f0f6db"', dropoff_datetime=datetime.datetime(2013, 11, 26, 11, 51, 40), dropoff_latitude=40.76206970214844, dropoff_longitude=-73.96826171875, hack_license=u'"912A2B86F30CDFE246586972A892367E"', medallion=u'"F3241FAB90B4B14FC46C3F11CC14B79E"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 11, 36, 54), pickup_latitude=40.77932357788086, pickup_longitude=-73.97745513916016, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=1.7000000476837158, trip_time_in_secs=886, vendor_id=u'"CMT"'),
 Row(id=u'"a0dbc88f34c35a620c3a33af7d447bb2"', rev=u'"1-09c485081ed511298abe1d5a0a976e67"', dropoff_datetime=datetime.datetime(2013, 2, 11, 20, 31, 18), dropoff_latitude=40.795536041259766, dropoff_longitude=-73.96687316894531, hack_license=u'"4CDB4439568A22F50E68E6C767583F0E"', medallion=u'"A5A8269908F5D906140559A300992053"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 20, 14, 6), pickup_latitude=40.73963165283203, pickup_longitude=-74.00267028808594, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=5.300000190734863, trip_time_in_secs=1031, vendor_id=u'"CMT"'),
 Row(id=u'"22d54bc53694ffa796879114d35dde53"', rev=u'"1-239114ce02a0b43667c2f5db2bb5d34f"', dropoff_datetime=datetime.datetime(2013, 11, 26, 8, 59, 34), dropoff_latitude=40.755271911621094, dropoff_longitude=-73.97235107421875, hack_license=u'"C5ADEC336825DEB30222ED03016EC2EA"', medallion=u'"AD1848EF6C8D8D832D8E9C8A83D58E32"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 8, 41, 52), pickup_latitude=40.77080535888672, pickup_longitude=-73.95088195800781, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=2.0999999046325684, trip_time_in_secs=1061, vendor_id=u'"CMT"'),
 Row(id=u'"57cf267a1fe6533edd94a5883b904a60"', rev=u'"1-0c2111ef3fbd25eb1775ce3fc460de29"', dropoff_datetime=datetime.datetime(2013, 11, 26, 12, 37, 56), dropoff_latitude=40.734100341796875, dropoff_longitude=-73.9888916015625, hack_license=u'"107A492A8269674DF2174B2A33D751C5"', medallion=u'"87D6A5AF77EA7F5F31213AADB50B7508"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 12, 24, 24), pickup_latitude=40.70307159423828, pickup_longitude=-74.01173400878906, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=4.400000095367432, trip_time_in_secs=811, vendor_id=u'"CMT"'),
 Row(id=u'"9114af73922c7cd9afac08d29f64917c"', rev=u'"1-9092bbcc1ee62333743272cf65ce3277"', dropoff_datetime=datetime.datetime(2013, 1, 11, 8, 38), dropoff_latitude=40.70155334472656, dropoff_longitude=-74.01187133789062, hack_license=u'"562E4437B93311AD764B17344145AA9A"', medallion=u'"1C6C70CC78475DA41DF18E897863F4B0"', passenger_count=2, pickup_datetime=datetime.datetime(2013, 1, 11, 8, 12), pickup_latitude=40.77445602416992, pickup_longitude=-73.95706939697266, rate_code=1, store_and_fwd_flag=u'', trip_distance=8.430000305175781, trip_time_in_secs=1560, vendor_id=u'"VTS"'),
 Row(id=u'"952ae0acb1d3a1dcbe4dbdebbabd81b5"', rev=u'"1-cef51bf1e73f95a3426e974cf6c750e2"', dropoff_datetime=datetime.datetime(2013, 2, 11, 14, 32, 20), dropoff_latitude=40.77259826660156, dropoff_longitude=-73.9824447631836, hack_license=u'"711FF480F454257CDB3DD2E67A080687"', medallion=u'"271217702A1E3484D03FE5B2B3E49146"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 14, 17), pickup_latitude=40.79769515991211, pickup_longitude=-73.97139739990234, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=1.899999976158142, trip_time_in_secs=919, vendor_id=u'"CMT"'),
 Row(id=u'"5c6680ae704e4ef370cd9d12f5c5b11c"', rev=u'"1-8246c032d15967ee0c8bb8d172d2d58c"', dropoff_datetime=datetime.datetime(2013, 2, 9, 20, 13), dropoff_latitude=40.72455978393555, dropoff_longitude=-74.00943756103516, hack_license=u'"43F2B464214B4F897BAF0D1DA4AF45D2"', medallion=u'"EB41562F6ECB5CA2630A85A1682D57FE"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 9, 20, 0), pickup_latitude=40.726890563964844, pickup_longitude=-73.98908996582031, rate_code=1, store_and_fwd_flag=u'', trip_distance=1.9900000095367432, trip_time_in_secs=780, vendor_id=u'"VTS"'),
 Row(id=u'"b06cb5d08bdc03b787b6f50f6c7bf488"', rev=u'"1-65b6e118b787d00ef2ae18584bb02cd7"', dropoff_datetime=datetime.datetime(2013, 11, 26, 19, 44, 36), dropoff_latitude=40.7413215637207, dropoff_longitude=-73.98870086669922, hack_license=u'"ED3A7E9C15A035BF3E9023240C11E432"', medallion=u'"9E627782FF35E9C2426B997D2C20DA3F"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 19, 27, 54), pickup_latitude=40.76643753051758, pickup_longitude=-73.95418548583984, rate_code=1, store_and_fwd_flag=u'"N"', trip_distance=3.0999999046325684, trip_time_in_secs=1002, vendor_id=u'"CMT"')]

Are we good? Well, not really… It seems that we still carry these quotes-within-quotes in our StringType variables. By now, the remedy should be obvious: just modify the second .map() operation above, so as to include a .strip('"') call for each of our StringType variables. Instead of leaving this to the reader as a trivial exercise, we take the opportunity to repeat the process and also demonstrate another way of building data frames, directly from the RDD of interest taxiNoHeader and without the temporary RDD taxi_temp: the rdd.toDF() method:

>>> taxi_df = taxiNoHeader.map(lambda k: k.split(",")).map(lambda p: (p[0].strip('"'), p[1].strip('"'), parse(p[2].strip('"')), float(p[3]), float(p[4]) , p[5].strip('"'), p[6].strip('"') , int(p[7]), parse(p[8].strip('"')), float(p[9]), float(p[10]), int(p[11]), p[12].strip('"'), float(p[13]), int(p[14]), p[15].strip('"')) ).toDF(schema)
>>> taxi_df.head(10)
[Row(id=u'e6b3fa7bee24a30c25ce87e44e714457', rev=u'1-9313152f4894bb47678d8ce98e9ec733', dropoff_datetime=datetime.datetime(2013, 2, 9, 18, 16), dropoff_latitude=40.73524856567383, dropoff_longitude=-73.99406433105469, hack_license=u'88F8DD623E5090083988CD32C84973E3', medallion=u'6B96DDFB5A50B96E72F5808ABE778B17', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 9, 17, 59), pickup_latitude=40.775123596191406, pickup_longitude=-73.96345520019531, rate_code=1, store_and_fwd_flag=u'', trip_distance=3.4600000381469727, trip_time_in_secs=1020, vendor_id=u'VTS'),
 Row(id=u'cbee283a4613f85af67f79c6d7721234', rev=u'1-c1bd2aecbf3936b30c486aa3deade97b', dropoff_datetime=datetime.datetime(2013, 1, 11, 17, 2), dropoff_latitude=40.826969146728516, dropoff_longitude=-73.94998931884766, hack_license=u'5514E59A5CEA0379EA6F7F12ABE87489', medallion=u'3541D0677EEEA07B67E645E12F04F517', passenger_count=1, pickup_datetime=datetime.datetime(2013, 1, 11, 16, 29), pickup_latitude=40.77362823486328, pickup_longitude=-73.87080383300781, rate_code=1, store_and_fwd_flag=u'', trip_distance=8.229999542236328, trip_time_in_secs=1980, vendor_id=u'VTS'),
 Row(id=u'e4fb64b76eb99d4ac222713eb36f1afb', rev=u'1-233ff643b7f105b7a76ec05cf4f0f6db', dropoff_datetime=datetime.datetime(2013, 11, 26, 11, 51, 40), dropoff_latitude=40.76206970214844, dropoff_longitude=-73.96826171875, hack_license=u'912A2B86F30CDFE246586972A892367E', medallion=u'F3241FAB90B4B14FC46C3F11CC14B79E', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 11, 36, 54), pickup_latitude=40.77932357788086, pickup_longitude=-73.97745513916016, rate_code=1, store_and_fwd_flag=u'N', trip_distance=1.7000000476837158, trip_time_in_secs=886, vendor_id=u'CMT'),
 Row(id=u'a0dbc88f34c35a620c3a33af7d447bb2', rev=u'1-09c485081ed511298abe1d5a0a976e67', dropoff_datetime=datetime.datetime(2013, 2, 11, 20, 31, 18), dropoff_latitude=40.795536041259766, dropoff_longitude=-73.96687316894531, hack_license=u'4CDB4439568A22F50E68E6C767583F0E', medallion=u'A5A8269908F5D906140559A300992053', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 20, 14, 6), pickup_latitude=40.73963165283203, pickup_longitude=-74.00267028808594, rate_code=1, store_and_fwd_flag=u'N', trip_distance=5.300000190734863, trip_time_in_secs=1031, vendor_id=u'CMT'),
 Row(id=u'22d54bc53694ffa796879114d35dde53', rev=u'1-239114ce02a0b43667c2f5db2bb5d34f', dropoff_datetime=datetime.datetime(2013, 11, 26, 8, 59, 34), dropoff_latitude=40.755271911621094, dropoff_longitude=-73.97235107421875, hack_license=u'C5ADEC336825DEB30222ED03016EC2EA', medallion=u'AD1848EF6C8D8D832D8E9C8A83D58E32', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 8, 41, 52), pickup_latitude=40.77080535888672, pickup_longitude=-73.95088195800781, rate_code=1, store_and_fwd_flag=u'N', trip_distance=2.0999999046325684, trip_time_in_secs=1061, vendor_id=u'CMT'),
 Row(id=u'57cf267a1fe6533edd94a5883b904a60', rev=u'1-0c2111ef3fbd25eb1775ce3fc460de29', dropoff_datetime=datetime.datetime(2013, 11, 26, 12, 37, 56), dropoff_latitude=40.734100341796875, dropoff_longitude=-73.9888916015625, hack_license=u'107A492A8269674DF2174B2A33D751C5', medallion=u'87D6A5AF77EA7F5F31213AADB50B7508', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 12, 24, 24), pickup_latitude=40.70307159423828, pickup_longitude=-74.01173400878906, rate_code=1, store_and_fwd_flag=u'N', trip_distance=4.400000095367432, trip_time_in_secs=811, vendor_id=u'CMT'),
 Row(id=u'9114af73922c7cd9afac08d29f64917c', rev=u'1-9092bbcc1ee62333743272cf65ce3277', dropoff_datetime=datetime.datetime(2013, 1, 11, 8, 38), dropoff_latitude=40.70155334472656, dropoff_longitude=-74.01187133789062, hack_license=u'562E4437B93311AD764B17344145AA9A', medallion=u'1C6C70CC78475DA41DF18E897863F4B0', passenger_count=2, pickup_datetime=datetime.datetime(2013, 1, 11, 8, 12), pickup_latitude=40.77445602416992, pickup_longitude=-73.95706939697266, rate_code=1, store_and_fwd_flag=u'', trip_distance=8.430000305175781, trip_time_in_secs=1560, vendor_id=u'VTS'),
 Row(id=u'952ae0acb1d3a1dcbe4dbdebbabd81b5', rev=u'1-cef51bf1e73f95a3426e974cf6c750e2', dropoff_datetime=datetime.datetime(2013, 2, 11, 14, 32, 20), dropoff_latitude=40.77259826660156, dropoff_longitude=-73.9824447631836, hack_license=u'711FF480F454257CDB3DD2E67A080687', medallion=u'271217702A1E3484D03FE5B2B3E49146', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 14, 17), pickup_latitude=40.79769515991211, pickup_longitude=-73.97139739990234, rate_code=1, store_and_fwd_flag=u'N', trip_distance=1.899999976158142, trip_time_in_secs=919, vendor_id=u'CMT'),
 Row(id=u'5c6680ae704e4ef370cd9d12f5c5b11c', rev=u'1-8246c032d15967ee0c8bb8d172d2d58c', dropoff_datetime=datetime.datetime(2013, 2, 9, 20, 13), dropoff_latitude=40.72455978393555, dropoff_longitude=-74.00943756103516, hack_license=u'43F2B464214B4F897BAF0D1DA4AF45D2', medallion=u'EB41562F6ECB5CA2630A85A1682D57FE', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 9, 20, 0), pickup_latitude=40.726890563964844, pickup_longitude=-73.98908996582031, rate_code=1, store_and_fwd_flag=u'', trip_distance=1.9900000095367432, trip_time_in_secs=780, vendor_id=u'VTS'),
 Row(id=u'b06cb5d08bdc03b787b6f50f6c7bf488', rev=u'1-65b6e118b787d00ef2ae18584bb02cd7', dropoff_datetime=datetime.datetime(2013, 11, 26, 19, 44, 36), dropoff_latitude=40.7413215637207, dropoff_longitude=-73.98870086669922, hack_license=u'ED3A7E9C15A035BF3E9023240C11E432', medallion=u'9E627782FF35E9C2426B997D2C20DA3F', passenger_count=1, pickup_datetime=datetime.datetime(2013, 11, 26, 19, 27, 54), pickup_latitude=40.76643753051758, pickup_longitude=-73.95418548583984, rate_code=1, store_and_fwd_flag=u'N', trip_distance=3.0999999046325684, trip_time_in_secs=1002, vendor_id=u'CMT')]
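An alternative sketch, not the approach used above: parse each line with Python's standard csv module, which removes the surrounding double quotes by itself, and drive the conversions from a list of converters (one per column) instead of one long lambda. Assumptions: Python 3, the timestamp layout seen in the sample rows, no record spanning multiple lines (textFile splits on newlines), and the hypothetical names CONVERTERS and parse_record. Empty fields become None rather than raising inside int() or float():

```python
import csv
from datetime import datetime

def to_timestamp(s):
    return datetime.strptime(s, '%Y-%m-%d %H:%M:%S')

# One converter per column, in file order (mirroring the schema above)
CONVERTERS = [str, str, to_timestamp, float, float, str, str, int,
              to_timestamp, float, float, int, str, float, int, str]

def parse_record(line):
    """Turn one raw CSV line into a tuple of typed values."""
    values = next(csv.reader([line]))      # quotes already removed here
    if len(values) != len(CONVERTERS):
        raise ValueError('expected %d fields, got %d'
                         % (len(CONVERTERS), len(values)))
    # Empty strings become None (null in Spark); others are converted
    return tuple(None if v == '' else conv(v)
                 for conv, v in zip(CONVERTERS, values))

# In PySpark (not run here): taxiNoHeader.map(parse_record).toDF(schema)
```

One design consequence: with this version the missing store_and_fwd_flag values become nulls rather than empty strings, so the == '' filters used below would have to test for null instead.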

Now that we have our dataframe, let’s run some simple pandas-like queries on it. For example, let’s see how many vendors and records per vendor there are in the dataset:

>>> taxi_df.groupBy("vendor_id").count().show()
vendor_id count
CMT       114387
VTS       135612

Recall that we have missing values in the field store_and_fwd_flag. But how many are there?

>>> taxi_df.filter(taxi_df.store_and_fwd_flag == '').count()
135616L

OK, the number of missing values looks dangerously close to the number of VTS vendor records. Is this a coincidence, or does vendor VTS indeed tend not to log this variable? Let’s explore this hypothesis:

>>> taxi_df.filter((taxi_df.store_and_fwd_flag == '') & (taxi_df.vendor_id == 'VTS')).count()
135612L

Well, we have a finding! Indeed, all records coming from vendor VTS have a missing value in this field…
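A note on combining column conditions: Python's and keyword cannot be overloaded, so x and y simply returns y whenever x is truthy, and the first condition is silently lost. The & operator, by contrast, calls __and__, which PySpark's Column overloads to build a combined condition; the parentheses are needed because & binds more tightly than ==. Depending on the Spark version, using and on Columns either raises an error or keeps only the second condition. The toy class below is a hypothetical stand-in, not a PySpark class, used only to show the plain-Python semantics:

```python
class Cond:
    """Hypothetical stand-in for a column condition (not a PySpark class)."""
    def __init__(self, label):
        self.label = label

    def __and__(self, other):           # invoked by the & operator
        return Cond('(%s AND %s)' % (self.label, other.label))

a = Cond("store_and_fwd_flag = ''")
b = Cond("vendor_id = 'VTS'")

print((a and b).label)   # vendor_id = 'VTS'   ('and' drops a)
print((a & b).label)     # (store_and_fwd_flag = '' AND vendor_id = 'VTS')
```

The SQL version of the same query further down, which uses a proper AND, also returns 135612.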
The dtypes attribute and the printSchema() method can be used to get information about the schema, which can be useful further down in the data processing pipeline:

>>> taxi_df.dtypes
[('id', 'string'),
 ('rev', 'string'),
 ('dropoff_datetime', 'timestamp'),
 ('dropoff_latitude', 'float'),
 ('dropoff_longitude', 'float'),
 ('hack_license', 'string'),
 ('medallion', 'string'),
 ('passenger_count', 'int'),
 ('pickup_datetime', 'timestamp'),
 ('pickup_latitude', 'float'),
 ('pickup_longitude', 'float'),
 ('rate_code', 'int'),
 ('store_and_fwd_flag', 'string'),
 ('trip_distance', 'float'),
 ('trip_time_in_secs', 'int'),
 ('vendor_id', 'string')]
>>> taxi_df.printSchema()
root
 |-- id: string (nullable = true)
 |-- rev: string (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- dropoff_latitude: float (nullable = true)
 |-- dropoff_longitude: float (nullable = true)
 |-- hack_license: string (nullable = true)
 |-- medallion: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- pickup_latitude: float (nullable = true)
 |-- pickup_longitude: float (nullable = true)
 |-- rate_code: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- trip_distance: float (nullable = true)
 |-- trip_time_in_secs: integer (nullable = true)
 |-- vendor_id: string (nullable = true)

Not familiar with pandas, but a SQL expert? No problem, Spark dataframes provide a SQL API as well. We first have to register the dataframe as a named temporary table; then, we can run the equivalents of the above queries as shown below:

>>> taxi_df.registerTempTable("taxi")
>>> sqlContext.sql("SELECT vendor_id, COUNT(*) FROM taxi GROUP BY vendor_id ").show()
vendor_id c1
CMT       114387
VTS       135612
>>> sqlContext.sql("SELECT COUNT(*) FROM taxi WHERE store_and_fwd_flag = '' ").show()
c0
135616
>>> sqlContext.sql("SELECT COUNT(*) FROM taxi WHERE vendor_id = 'VTS' AND store_and_fwd_flag = '' ").show()
c0
135612

Notice that, unlike standard SQL, table and column names here are case sensitive (at least with the plain SQLContext used in this Spark version), e.g. TAXI or vendor_ID in the queries will produce an error.

Now, imagine that at this point we want to change some column names: say, we want to shorten pickup_latitude to pickup_lat, and similarly for the other 3 columns with lat/long information. We certainly do not want to run the whole procedure above from the beginning; we might not even have access to the initial CSV data, but only to the dataframe. We can do that using the dataframe method withColumnRenamed, chained as many times as required:

>>> taxi_df = taxi_df.withColumnRenamed('dropoff_longitude', 'dropoff_long').withColumnRenamed('dropoff_latitude', 'dropoff_lat').withColumnRenamed('pickup_latitude', 'pickup_lat').withColumnRenamed('pickup_longitude', 'pickup_long')
>>> taxi_df.dtypes
[('id', 'string'),
 ('rev', 'string'),
 ('dropoff_datetime', 'timestamp'),
 ('dropoff_lat', 'float'),
 ('dropoff_long', 'float'),
 ('hack_license', 'string'),
 ('medallion', 'string'),
 ('passenger_count', 'int'),
 ('pickup_datetime', 'timestamp'),
 ('pickup_lat', 'float'),
 ('pickup_long', 'float'),
 ('rate_code', 'int'),
 ('store_and_fwd_flag', 'string'),
 ('trip_distance', 'float'),
 ('trip_time_in_secs', 'int'),
 ('vendor_id', 'string')]
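When many columns need renaming, the chain can be driven by a dictionary instead. A minimal sketch, relying only on DataFrame.withColumnRenamed(existing, new) as used above; the helper rename_columns and the dict SHORT_NAMES are hypothetical names:

```python
from functools import reduce

def rename_columns(df, mapping):
    """Apply withColumnRenamed once per (old, new) pair in mapping."""
    return reduce(lambda acc, pair: acc.withColumnRenamed(*pair),
                  mapping.items(), df)

SHORT_NAMES = {
    'dropoff_longitude': 'dropoff_long',
    'dropoff_latitude': 'dropoff_lat',
    'pickup_latitude': 'pickup_lat',
    'pickup_longitude': 'pickup_long',
}

# With a real Spark dataframe (not run here):
#   taxi_df = rename_columns(taxi_df, SHORT_NAMES)
```

Since each withColumnRenamed call returns a new dataframe, reduce simply threads the result of one rename into the next, exactly like the explicit chain above.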

Finally, let’s make a selection from our dataframe and convert the selected rows to pandas format. Of course, given the size of our dataset, we could directly convert all of it to a pandas dataframe; however, this will generally not be possible in a real situation, where the dataset may involve millions of rows and hundreds of gigabytes. So, let’s demonstrate a row selection.
Let’s say that we want to keep only the records from vendor CMT that do not have missing values in store_and_fwd_flag column, and store the result in a pandas dataframe:

>>> import pandas as pd
>>> taxi_CMT = taxi_df.filter("vendor_id = 'CMT' and store_and_fwd_flag != '' ").toPandas()
>>> taxi_CMT.head()
                                 id                                 rev  \
0  e4fb64b76eb99d4ac222713eb36f1afb  1-233ff643b7f105b7a76ec05cf4f0f6db
1  a0dbc88f34c35a620c3a33af7d447bb2  1-09c485081ed511298abe1d5a0a976e67
2  22d54bc53694ffa796879114d35dde53  1-239114ce02a0b43667c2f5db2bb5d34f
3  57cf267a1fe6533edd94a5883b904a60  1-0c2111ef3fbd25eb1775ce3fc460de29
4  952ae0acb1d3a1dcbe4dbdebbabd81b5  1-cef51bf1e73f95a3426e974cf6c750e2   

     dropoff_datetime  dropoff_lat  dropoff_long  \
0 2013-11-26 11:51:40    40.762070    -73.968262
1 2013-02-11 20:31:18    40.795536    -73.966873
2 2013-11-26 08:59:34    40.755272    -73.972351
3 2013-11-26 12:37:56    40.734100    -73.988892
4 2013-02-11 14:32:20    40.772598    -73.982445   

                       hack_license                         medallion  \
0  912A2B86F30CDFE246586972A892367E  F3241FAB90B4B14FC46C3F11CC14B79E
1  4CDB4439568A22F50E68E6C767583F0E  A5A8269908F5D906140559A300992053
2  C5ADEC336825DEB30222ED03016EC2EA  AD1848EF6C8D8D832D8E9C8A83D58E32
3  107A492A8269674DF2174B2A33D751C5  87D6A5AF77EA7F5F31213AADB50B7508
4  711FF480F454257CDB3DD2E67A080687  271217702A1E3484D03FE5B2B3E49146   

   passenger_count     pickup_datetime  pickup_lat  pickup_long  rate_code  \
0                1 2013-11-26 11:36:54   40.779324   -73.977455          1
1                1 2013-02-11 20:14:06   40.739632   -74.002670          1
2                1 2013-11-26 08:41:52   40.770805   -73.950882          1
3                1 2013-11-26 12:24:24   40.703072   -74.011734          1
4                1 2013-02-11 14:17:00   40.797695   -73.971397          1   

  store_and_fwd_flag  trip_distance  trip_time_in_secs vendor_id
0                  N            1.7                886       CMT
1                  N            5.3               1031       CMT
2                  N            2.1               1061       CMT
3                  N            4.4                811       CMT
4                  N            1.9                919       CMT  

[5 rows x 16 columns]

taxi_CMT is now a pandas dataframe, and all the relevant functionality is available…
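Because taxi_CMT is an ordinary pandas dataframe, the usual pandas idioms apply directly. The following is a minimal sketch. To keep it runnable without Spark, it rebuilds a hypothetical stand-in for taxi_CMT by hand from the five rows (and a subset of the columns) printed above. It then adds a derived average-speed column, selects rows with a boolean mask, and aggregates per pickup date. The speed computation assumes that trip_distance is measured in miles.

```python
import pandas as pd

# Hypothetical stand-in for taxi_CMT: the five rows (and a subset of the
# columns) shown in the head() output above, rebuilt so this runs without Spark.
taxi_CMT = pd.DataFrame({
    "pickup_datetime": pd.to_datetime([
        "2013-11-26 11:36:54", "2013-02-11 20:14:06", "2013-11-26 08:41:52",
        "2013-11-26 12:24:24", "2013-02-11 14:17:00",
    ]),
    "passenger_count": [1, 1, 1, 1, 1],
    "trip_distance": [1.7, 5.3, 2.1, 4.4, 1.9],
    "trip_time_in_secs": [886, 1031, 1061, 811, 919],
    "vendor_id": ["CMT"] * 5,
})

# Derived column: average speed in miles per hour
# (assumption: trip_distance is in miles).
taxi_CMT["avg_speed_mph"] = taxi_CMT["trip_distance"] / (taxi_CMT["trip_time_in_secs"] / 3600.0)

# Plain boolean-mask row selection on a numeric column.
long_trips = taxi_CMT[taxi_CMT["trip_distance"] > 2.0]

# Total distance per pickup date, using the datetime accessor .dt.
per_day = taxi_CMT.groupby(taxi_CMT["pickup_datetime"].dt.date)["trip_distance"].sum()

print(long_trips[["trip_distance", "avg_speed_mph"]].round(2))
print(per_day)
```

These lines should work unchanged on the real taxi_CMT, because they rely only on columns that appear in the schema listed above.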

Row selection in Spark dataframes using numeric or string column values is as straightforward as demonstrated above; for timestamp columns, things are more complicated, and we’ll cover this issue in a future post.

* * *

We hope we have given a handy demonstration of how to construct Spark dataframes from CSV files with headers. There already exist some third-party packages, like [EDIT: spark-csv and] pyspark-csv, that attempt to do this in an automated manner, more or less similar to R’s read.csv or pandas’ read_csv; we have not tried them yet, but we hope to do so in a near-future post.
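Packages such as spark-csv and pyspark-csv automate the two steps done by hand in this post. First, they treat the first line as column names rather than data. Second, they assign a type to each column. The toy sketch below illustrates that idea in plain Python with the standard csv module. It is an illustration of header handling and type inference under simple assumptions, not a description of how those packages are actually implemented. The function names (read_csv_with_header, infer_type) and the sample rows are hypothetical.

```python
import csv
import io


def infer_type(values):
    """Return the narrowest of int, float, str that parses every non-empty value."""
    non_empty = [v for v in values if v != ""]
    if not non_empty:
        return str  # an all-empty column gives no evidence; fall back to str
    for typ in (int, float):
        try:
            for v in non_empty:
                typ(v)
        except ValueError:
            continue  # at least one value does not parse; try the next type
        return typ
    return str


def read_csv_with_header(text):
    """Toy read.csv: first row -> column names, then one inferred type per column."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)        # column names come from the header row
    rows = list(reader)          # data rows only; the header is not mixed in
    columns = list(zip(*rows))   # transpose into per-column tuples of strings
    types = [infer_type(col) for col in columns]
    # Empty fields become None (missing); everything else is cast to its column type.
    records = [
        {name: (None if v == "" else t(v)) for name, v, t in zip(header, row, types)}
        for row in rows
    ]
    return dict(zip(header, types)), records


# Hypothetical sample with the same kind of header as the taxi file.
sample = (
    "vendor_id,passenger_count,trip_distance,store_and_fwd_flag\n"
    "CMT,1,1.7,N\n"
    "CMT,1,5.3,\n"
)
schema, records = read_csv_with_header(sample)
print(schema)
print(records)
```

A real implementation would also have to deal with dates and timestamps, malformed rows, and very large inputs, where scanning every value just to infer column types can be expensive.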
As said, all the code demonstrated above can be found in a downloadable IPython notebook. Comments and remarks are most welcome.

Christos - Iraklis Tsatsoulis
16 Comments
Reynold Xin
May 29, 2015 21:11

Thanks for the article. Reynold here, who wrote most of the DataFrame API. It is still experimental, so we would love to see more feedback if you have any (my email: rxin at databricks.com). There are two ways to easily use CSV.

One is https://github.com/databricks/spark-csv

The other is to load it in using Pandas, and just call sqlContext.createDataFrame(pandasDF).

Parag
Reply to  Christos - Iraklis Tsatsoulis
July 19, 2015 14:33

Hi, this is undoubtedly very informative, thanks for the post. However, I am new to Spark and Python, and I am having problems getting values out of a dataframe/row and loading them into variables for further processing. For example, if after getting the max of pickup_datetime I want to store it in a variable, how do I do it?

Patrick
September 18, 2015 12:45

Hey, thanks for the article. I really liked it, but working through this stuff myself I ran into some problems. At the moment I am trying to find out how to update a column (datetime.datetime to date) or add a new column (based on other columns) to my df.

If you have any suggestions it would be very kind.

trackback
December 22, 2015 16:23

[…] writing the previous post on Spark dataframes, I encountered an unexpected behavior of the respective .filter method; but, on the one hand, I […]

Stella
April 15, 2016 21:45

Thanks for the article! Question: it seems that you (and every other post I’ve found) are manually casting the rdd data types. Is there a way to do this programmatically (similar to what structtype does for dataframes)? Context/my problem: I have a data.csv file without headers. I also have a metadata.csv which contains column names and their respective data types. I used the metadata.csv to generate a structtype which I named final_schema. I would like to pull my data.csv into a dataframe with the appropriate schema applied. When I attempt to do: df = sqlContext.read.format('csv').load( data.csv,…

trackback
July 4, 2016 13:02

[…] new columns from existing ones in Spark dataframes is a frequently raised question – see Patrick’s comment in our previous post); then, we will check in how many records this is false (i.e. dropoff seems to […]

Nandha
December 15, 2016 22:44

Best explained.. Thanks CHRISTOS – IRAKLIS TSATSOULIS for sharing the knowledge.

Soufiane
December 25, 2016 05:14

Great article/tutorial , i like it.
Many thanks

Joel
March 2, 2017 16:55

Insanely helpful! Thank you!

Manish Gupta
March 12, 2018 11:51

Great article, but when I try to build a Hive table on top of this, I get an error. The temporary table is created, but the actual Hive table is not.

Ron Zbaida
July 29, 2018 19:30

a lifesaver post!