Logistic regression using Apache Spark

This article is an introduction to the Apache Spark data processing engine, with a focus on its machine learning algorithms. We will start with a few words about Spark and then work through a practical machine learning exercise using the Qualitative Bankruptcy dataset from the UCI Machine Learning Repository. Although Spark supports Java, Scala, Python and R, this tutorial uses Scala. Don’t worry if you have no experience with Scala: every code snippet is explained as we go through the exercise.

At the end of the article you will find a link to a Python version of the demo on GitHub.

Apache Spark

Apache Spark is an open-source cluster computing framework which, for certain applications, can be up to 100 times faster than Hadoop MapReduce. One of the main features Spark offers for speed is the ability to run computations in memory, but it is also more efficient than MapReduce for complex applications running on disk. Spark also aims to be general-purpose, and for that reason it offers the following libraries:

  • Spark SQL, a module for working with structured data
  • MLlib, a scalable machine learning library
  • GraphX, an API for graphs and graph-parallel computation
  • Spark Streaming, an API for scalable, fault-tolerant streaming applications

As already mentioned, Spark supports Java, Scala, Python and R. It also integrates closely with other big data tools. In particular, Spark can run in Hadoop clusters and can access any Hadoop data source, including Cassandra.

Spark core concepts

At a high level, a Spark application consists of a driver program that launches various parallel operations on a cluster. The driver program contains your application’s main function; it defines distributed datasets and applies operations to them, and Spark ships those operations to the cluster’s worker nodes for execution. The driver program uses a SparkContext object to access the computing cluster. In the Spark shell, a SparkContext is created automatically and is available through the sc variable.
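In a standalone program, the driver creates the SparkContext itself instead of relying on the shell’s sc. The following is a minimal sketch, assuming the RDD API and a local master URL (local[*] runs everything in one JVM on all available cores); the application name is a placeholder. SparkContext.getOrCreate returns the already-running context if there is one, so the same lines also work when pasted into the shell.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// "local[*]" runs the driver and executors in a single JVM using all cores;
// on a real cluster this would be the cluster manager's URL instead.
val conf = new SparkConf()
  .setAppName("bankruptcy-demo") // placeholder application name
  .setMaster("local[*]")

// Returns the running context if one exists (such as the shell's sc),
// otherwise creates a new one from conf.
val sc = SparkContext.getOrCreate(conf)

// The driver only describes the job; the summation runs as parallel tasks.
val total = sc.parallelize(1 to 100).reduce(_ + _)
println(s"Sum computed by the executors: $total") // prints Sum computed by the executors: 5050
```

In a compiled application these statements would sit inside the driver’s main method.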

A very important concept in Spark is the RDD, or resilient distributed dataset: an immutable, distributed collection of objects. Each RDD is split into multiple partitions, which may be computed on different nodes of the cluster. An RDD can contain any type of Java, Scala, Python or R object, including user-defined classes. RDDs can be created in two ways: by loading an external dataset, or by distributing a collection of objects such as a list or a set.
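Both ways of creating an RDD can be sketched as follows, assuming a local SparkContext (the app name is a placeholder) and the file path used later in this tutorial. The small list of rating codes is illustrative only.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("rdd-creation").setMaster("local[*]"))

// 1) Distribute an in-memory collection, explicitly split into 3 partitions.
val codes = sc.parallelize(List("P", "A", "N", "A", "P"), numSlices = 3)
println(s"partitions: ${codes.partitions.length}, elements: ${codes.count()}") // prints partitions: 3, elements: 5

// 2) Load an external dataset: one RDD element per line of the text file.
// textFile() is lazy, so nothing is read (and a missing file is not reported)
// until an action runs on this RDD.
val lines = sc.textFile("playground/Qualitative_Bankruptcy.data.txt")

// RDDs are immutable: map() returns a new RDD and leaves codes unchanged.
val lowerCodes = codes.map(_.toLowerCase)
```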

Once created, RDDs support two types of operations:

  • Transformations – construct a new RDD from an existing one
  • Actions – compute a result based on an RDD

RDDs are computed lazily, that is, only when they are used in an action. Because Spark sees the whole chain of transformations before computing anything, it can compute just the data needed for the result. By default, an RDD is recomputed each time an action is run on it. If you need the RDD for multiple actions, you can ask Spark to persist it using RDD.persist().

You can use Spark from a shell session or as a standalone program. Either way, you will have the following workflow:

  • create input RDDs
  • transform them using transformations
  • ask Spark to persist them if needed for reuse
  • launch actions to start parallel computation, which is then optimized and executed by Spark
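The workflow above, including lazy transformations and persist(), can be sketched on a small in-memory RDD. The values and names are illustrative assumptions; MEMORY_ONLY is the same storage level that cache() uses.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("rdd-workflow").setMaster("local[*]"))

// 1. Create an input RDD (illustrative codes in the bankruptcy data format).
val values = sc.parallelize(Seq("P", "A", "N", "NB", "B", "P", "A"))

// 2. Transformations only describe new RDDs; nothing is computed yet.
val ratings = values.filter(v => v == "P" || v == "A" || v == "N")
val positives = ratings.filter(_ == "P")

// 3. ratings is reused by several actions below, so keep it in memory
//    once it has been computed.
ratings.persist(StorageLevel.MEMORY_ONLY)

// 4. Actions trigger the computation; later ones reuse the cached partitions.
val ratingCount = ratings.count()
val distinctRatings = ratings.distinct().collect().sorted
val positiveCount = positives.count()
println(s"$ratingCount ratings, distinct: ${distinctRatings.mkString(",")}, positive: $positiveCount")
// prints 5 ratings, distinct: A,N,P, positive: 2
```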

Installing Apache Spark

To start using Spark, download it from http://spark.apache.org/downloads.html. Select the package type “Pre-built for Hadoop 2.4 and later” and click “Direct Download”. Windows users are advised to install Spark into a directory without spaces in its name, for instance by extracting the archive to C:\spark.

As mentioned before, we will be using Scala. Navigate to the directory where Spark was extracted and start the Scala shell by running bin/spark-shell (bin\spark-shell on Windows).

Once the shell has started, you should see the scala> prompt, ready to take your commands.

Qualitative Bankruptcy classification

The real-life problem we are trying to solve is predicting whether a company will go bankrupt, given qualitative information about it. The data set can be downloaded from the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/Qualitative_Bankruptcy. Inside Spark’s installation folder, create a new folder, named playground for instance, and copy the Qualitative_Bankruptcy.data.txt file from the archive into it. This file contains the data we will use for both training and testing.

The data set contains 250 instances: 143 non-bankruptcy instances and 107 bankruptcy instances.

For each instance, the following qualitative parameters are available:

  • Industrial Risk
  • Management Risk
  • Financial Flexibility
  • Credibility
  • Competitiveness
  • Operating Risk

These are called qualitative parameters because they are expressed as categories rather than as numbers. Each parameter can take one of the following values:

  • P positive
  • A average
  • N negative

The last column in the data set is the class of each instance: B for bankruptcy or NB for non-bankruptcy.

Given this data set, we must train a model that can then be used to classify new instances, a typical classification problem.

The approach for this problem is the following:

  • read the data from the Qualitative_Bankruptcy.data.txt file
  • parse each qualitative value and transform it into a numeric Double value, as required by the classification algorithm
  • split the data into training and test data sets
  • train the model using the training data
  • calculate the error of the model on the test data

Spark logistic regression

We will use Spark’s logistic regression algorithm to train the classification model. To learn more about how logistic regression works, you can also read the following tutorial: https://technobium.com/logistic-regression-using-apache-mahout.
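For reference, binary logistic regression models the probability that an instance belongs to class 1 as the logistic (sigmoid) function of a weighted sum of its features. Here x is the vector of encoded parameters, w the weight vector learned during training and b an intercept term; by default, an MLlib LogisticRegressionModel predicts class 1 when this probability exceeds 0.5.

```latex
P(y = 1 \mid \mathbf{x}) = \frac{1}{1 + e^{-(\mathbf{w}^{\top}\mathbf{x} + b)}},
\qquad
\hat{y} =
\begin{cases}
1 & \text{if } P(y = 1 \mid \mathbf{x}) > 0.5,\\
0 & \text{otherwise.}
\end{cases}
```

Training, here with the L-BFGS optimizer, chooses w (and b, if an intercept is fitted) to minimize the logistic loss over the training data.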

At the Scala shell type or paste the following import statements:
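The original listing of imports is not reproduced here. A plausible set, assuming the RDD-based MLlib API (the org.apache.spark.mllib packages, which provide LabeledPoint, dense vectors and LogisticRegressionWithLBFGS), is shown below together with a one-line check that the classes resolve.

```scala
// RDD-based MLlib classes used in the rest of this exercise (assumed set).
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Quick check that the imports resolve: one labeled example built by hand.
val example = LabeledPoint(1.0, Vectors.dense(3.0, 2.0, 1.0))
println(example) // prints (1.0,[3.0,2.0,1.0])
```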

This will import the needed libraries.

Next, we will create a Scala function that transforms each qualitative value in the data set into a Double. Type or paste the following snippet and hit Enter in the Scala shell (thank you, David Edwards, for optimizing the function).
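The exact mapping used for this tutorial is not shown here. The sketch below assumes P = 3.0, A = 2.0 and N = 1.0 for the parameters, and NB = 1.0, B = 0.0 for the class. The class values are not arbitrary: for two classes, MLlib’s logistic regression expects labels 0.0 and 1.0. A pattern match keeps the function compact, and the catch-all case (a design choice of this sketch) fails fast on unexpected input instead of silently returning a default.

```scala
// Map each qualitative code in the data set to a Double.
// Assumed (hypothetical) encoding:
//   parameters:   P -> 3.0, A -> 2.0, N -> 1.0
//   class labels: NB -> 1.0, B -> 0.0  (binary labels must be 0.0 or 1.0)
def getDoubleValue(input: String): Double = input.trim match {
  case "P"  => 3.0
  case "A"  => 2.0
  case "N"  => 1.0
  case "NB" => 1.0
  case "B"  => 0.0
  case other =>
    throw new IllegalArgumentException(s"Unexpected value: '$other'")
}

println(Seq("P", "A", "N", "NB", "B").map(getDoubleValue)) // prints List(3.0, 2.0, 1.0, 1.0, 0.0)
```

The trim() call also tolerates stray spaces or a trailing carriage return in the input file.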

If all runs without problems, the shell confirms the definition by echoing the function’s signature.

Now we will read the lines of the Qualitative_Bankruptcy.data.txt file. Like a transformation, creating this RDD is lazy: at this stage, the data is not actually read into memory. The actual reading is triggered by the count() call, which is an action.
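A sketch of this step, assuming the shell was started from Spark’s installation folder so that the relative path to the playground folder resolves (adjust the path otherwise). The SparkContext line only matters outside the shell; inside it, getOrCreate simply returns the existing sc.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("bankruptcy").setMaster("local[*]"))

// Lazy: this records where the data lives, but reads nothing yet.
// Path assumed relative to the directory the shell was started from.
val data = sc.textFile("playground/Qualitative_Bankruptcy.data.txt")

// Action: Spark now reads the file and counts its lines.
val lineCount = data.count()
println(s"Lines read: $lineCount")
```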

Using the val keyword, we have declared an immutable value named data. It holds a single RDD with one element for each line of the input file. The RDD is created through sc, the SparkContext variable. The count should return 250, the number of instances in the data set.

It is now time to parse all the data and prepare it for the logistic regression algorithm, which operates on numerical values, not Strings.
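The parsing step can be sketched as follows. To keep the block self-contained, a small lookup table (encode, a hypothetical name) stands in for getDoubleValue(), using the assumed encoding P = 3.0, A = 2.0, N = 1.0 and NB = 1.0, B = 0.0. The example line is made up but follows the file’s comma-separated layout of six parameters followed by the class.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical encoding standing in for getDoubleValue().
val encode: Map[String, Double] =
  Map("P" -> 3.0, "A" -> 2.0, "N" -> 1.0, "NB" -> 1.0, "B" -> 0.0)

// Columns 0-5 hold the six parameters, column 6 holds the class.
def parseLine(line: String): LabeledPoint = {
  val parts = line.split(",").map(_.trim)
  require(parts.length == 7, s"expected 7 columns, got ${parts.length}: $line")
  LabeledPoint(encode(parts(6)), Vectors.dense(parts.slice(0, 6).map(encode)))
}

// Made-up example line in the data set's format.
println(parseLine("P,P,A,A,A,P,NB")) // prints (1.0,[3.0,3.0,2.0,2.0,2.0,3.0])

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("bankruptcy").setMaster("local[*]"))
val data = sc.textFile("playground/Qualitative_Bankruptcy.data.txt")

// map() is a transformation: no line is parsed until an action runs.
val parsedData = data.map(parseLine)
```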

Here we have declared another constant named parsedData. For each line in the data variable we will do the following:

  • split the line on the “,” character to obtain an array of parts
  • create and return a LabeledPoint object. Each LabeledPoint contains a label and a Vector of feature values. In our data, the label or category (bankruptcy or non-bankruptcy) is found in the last column, index 6 counting from 0, which is what parts(6) refers to. Before saving the label, we transform the String into a Double using the getDoubleValue() function we prepared before. The remaining values are also transformed into Doubles and stored in a dense vector, the feature data structure expected by Spark’s logistic regression algorithm.

From Spark’s perspective, this is a map() transformation, which is only executed when an action is encountered.

Let’s see what our data looks like by using a simple take() action.

This tells Spark to take the first 10 elements of the parsedData RDD and print them to the console. Because take() is an action, it triggers the actual execution of the previous map() transformation. Each printed element shows a label followed by the vector of six encoded parameters.
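A self-contained sketch of take(): a small hypothetical RDD of generated LabeledPoints, built with map(), stands in for parsedData. In the tutorial’s shell session the equivalent call would be parsedData.take(10).foreach(println).

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("take-demo").setMaster("local[*]"))

// Hypothetical stand-in for parsedData: 20 generated points with 6 features.
// map() is lazy, so no LabeledPoint exists yet.
val sample = sc.parallelize(1 to 20).map { i =>
  LabeledPoint((i % 2).toDouble, Vectors.dense(Array.fill(6)((i % 3 + 1).toDouble)))
}

// take(n) is an action: it runs the map() on only as many partitions as it
// needs and returns the first n elements to the driver as a local array.
val firstTen = sample.take(10)
firstTen.foreach(println) // first line: (1.0,[2.0,2.0,2.0,2.0,2.0,2.0])
```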

We will now randomly split parsedData into 60% training data and 40% test data.

You can inspect the results by issuing take() or count() actions on trainingData and testData.
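The split can be sketched with randomSplit(), which takes the relative weights of the parts and an optional seed. A hypothetical RDD of numbers stands in for parsedData, and the seed value is an arbitrary assumption that makes the split reproducible. Because each element is assigned to a part at random, the sizes are only approximately 60% and 40%.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("split-demo").setMaster("local[*]"))

// Hypothetical stand-in for parsedData.
val records = sc.parallelize(1 to 1000)

// Weights 0.6 and 0.4; a fixed (arbitrary) seed gives the same split on every run.
val Array(trainingData, testData) = records.randomSplit(Array(0.6, 0.4), seed = 11L)

// Inspect the result with count() actions.
println(s"training: ${trainingData.count()}, test: ${testData.count()}")
```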

We can now train the model using LogisticRegressionWithLBFGS, specifying the number of classes as 2 (bankruptcy and non-bankruptcy).
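A minimal training sketch. The four training rows are hypothetical and already encoded with the assumed scheme (features P = 3.0, A = 2.0, N = 1.0; labels NB = 1.0, B = 0.0); in the tutorial, trainingData comes from the previous split. setNumClasses(2) corresponds to the two classes mentioned above.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("train-demo").setMaster("local[*]"))

// Hypothetical, already-encoded training rows (6 features each).
val trainingData = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(3.0, 3.0, 2.0, 3.0, 3.0, 2.0)),
  LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0, 2.0, 3.0, 3.0)),
  LabeledPoint(0.0, Vectors.dense(1.0, 1.0, 1.0, 2.0, 1.0, 1.0)),
  LabeledPoint(0.0, Vectors.dense(2.0, 1.0, 1.0, 1.0, 1.0, 2.0))
)).cache() // L-BFGS makes many passes over the data, so keep it in memory

// Binary logistic regression optimized with L-BFGS.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(trainingData)

println(s"weights: ${model.weights}, intercept: ${model.intercept}")
```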

When the model is ready, we can calculate its error on testData.
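A self-contained sketch of the evaluation step. To make the result predictable, it uses a hand-built LogisticRegressionModel with hypothetical weights instead of the trained model, and four made-up test rows; in the tutorial, model and testData come from the previous steps. With these weights the model predicts 1.0 only when the second feature is 3.0 (P), so exactly one of the four rows is misclassified.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("eval-demo").setMaster("local[*]"))

// Hypothetical model: margin = x(1) - 2.5, so with the default 0.5 threshold
// it predicts 1.0 exactly when the second feature is 3.0.
val model = new LogisticRegressionModel(Vectors.dense(0.0, 1.0, 0.0, 0.0, 0.0, 0.0), -2.5)

// Made-up test rows; only the second one is misclassified.
val testData = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(3.0, 3.0, 2.0, 3.0, 3.0, 2.0)),
  LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 3.0, 3.0, 3.0, 3.0)),
  LabeledPoint(0.0, Vectors.dense(1.0, 1.0, 1.0, 2.0, 1.0, 1.0)),
  LabeledPoint(0.0, Vectors.dense(2.0, 1.0, 1.0, 1.0, 1.0, 2.0))
))

// Pair each expected label with the model's prediction.
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}

// Fraction of test rows whose prediction differs from the label.
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test error: $testErr") // prints Test error: 0.25
```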

The labelAndPreds value holds the result of a map() transformation, which returns a tuple for each row of testData. Each tuple contains the expected value for the row (point.label) and the predicted value (prediction), which is obtained by applying the model to the row’s features (point.features).

The last line computes the test error using a filter() transformation and a count() action. The filter transformation keeps only those tuples for which the predicted and the expected values differ. In Scala, _1 and _2 access the first and second elements of a tuple. Dividing the number of wrong predictions by the number of elements in testData gives the test error, the fraction of misclassified test instances. Its exact value depends on the random split.

GitHub repository for this project: https://github.com/technobium/spark-logistic-regression

Mustafa Elsheikh took the time to write a Python version of the demo: https://github.com/elsheikhmh/workspace/blob/master/Qualitative_Bankruptcy.py. Thank you very much, Mustafa.


In this tutorial you have seen how Apache Spark can be used for machine learning tasks like logistic regression. Although this demo ran in a single Scala shell, the real power of Spark lies in its ability to process data in memory and in parallel across a cluster.



References

Karau, H., Konwinski, A., Wendell, P. and Zaharia, M. (2015). Learning Spark. O’Reilly Media.

Lichman, M. (2013). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.






