Logistic regression using Apache Mahout

Logistic regression is a supervised learning algorithm used to classify input data into categories. With two possible categories we use binary (binomial) logistic regression; with more than two categories we use multinomial logistic regression. In binary logistic regression, training finds the parameters (weights) of the function that best fits the training data. This function is the sigmoid (logistic) function, which takes values between 0 and 1. The classifier then applies the trained function to new input data and returns the probability that the input belongs to one category or the other.
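To make the last point concrete, here is a minimal Java sketch of how a trained binary model scores an input: it computes the dot product of a weight vector with the feature vector and passes the result through the sigmoid. The class name and the weights are hypothetical and chosen only for illustration; they are not produced by Mahout.

```java
// Minimal sketch: how a trained binary logistic regression model scores one input.
// The weights used in main() are hypothetical, not produced by Mahout.
final class SigmoidSketch {

    // Sigmoid (logistic) function: maps any real number into the interval (0, 1)
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Probability of category 1: sigmoid of the dot product of weights and features.
    // By convention here, features[0] is the intercept term and is always 1.
    static double probabilityOfOne(double[] weights, double[] features) {
        double z = 0.0;
        for (int i = 0; i < weights.length; i++) {
            z += weights[i] * features[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] weights = { -1.0, 0.5 };  // hypothetical intercept weight and feature weight
        double[] features = { 1.0, 4.0 };  // intercept term, one feature value
        double p1 = probabilityOfOne(weights, features);
        System.out.println("P(category 1) = " + p1);
        System.out.println("P(category 0) = " + (1.0 - p1));
    }
}
```

With these weights z = -1 + 0.5 * 4 = 1, so P(category 1) = sigmoid(1), roughly 0.731, and P(category 0) is roughly 0.269; the two probabilities always sum to 1.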

In this tutorial we will use the logistic regression algorithm to solve a fraud detection problem. In the following sections we will define the problem, explain how logistic regression helps solve it, and see how to build a small working example in Java. Basic Java programming knowledge is required for this tutorial.

Detect car mileage fraud using logistic regression

Mileage fraud, tachograph fraud, odometer fraud, “busting miles” or “clocking” all mean fraudulently manipulating a vehicle's mileage indicator in order to increase its market value. There are many tips and tricks you can use to tell whether a car has been subject to this type of fraud. For this example we will use a common-sense check: if the car has too few miles/kilometers for its age, we can suspect that the mileage indicator was manipulated. Logistic regression is the machine learning algorithm we can use to automate the decision whether a car is suspected of mileage fraud. We will use a small data set to train the logistic regression algorithm. The output of the training step will be a mathematical function, which we will use to predict the fraud probability for new input data.

In production environments, training is usually executed against a much larger data set (millions of entries). Also, the input data is split into two parts: one on which the model is trained and one on which the quality of the trained model is verified. For instance, 60% of the initial data can be used to train the model and the remaining 40% to test it.
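A random split like the one described above can be sketched in plain Java. The class and method names are hypothetical; shuffling with a fixed seed keeps the split reproducible between runs, and the 0.6 fraction corresponds to the 60/40 example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Minimal sketch of a random train/test split, generic over the row type
final class TrainTestSplit {

    // Returns two lists: index 0 holds the training rows, index 1 the test rows
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        List<T> shuffled = new ArrayList<>(rows);
        // A fixed seed makes the split reproducible
        Collections.shuffle(shuffled, new Random(seed));
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> result = new ArrayList<>();
        result.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        result.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return result;
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            rows.add(i);
        }
        List<List<Integer>> parts = split(rows, 0.6, 42L);
        System.out.println("Training rows: " + parts.get(0).size());
        System.out.println("Test rows:     " + parts.get(1).size());
    }
}
```

For 10 rows this prints 6 training rows and 4 test rows. Shuffling before splitting matters when the input file is ordered, as the training data in this tutorial is (grouped by car type).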

In our example, the training data comes from an input file. Each line contains the car type, car age, car mileage and, in the last column, a binary value that tells whether the mileage was manipulated (1) or not (0). I used generic car types, but you can use real-life car models. As you can see in the first line, a 10-year-old small car with a mileage of 100000 is not suspected of manipulation. The last line, however, shows that a sport car of the same age and with the same mileage can be suspected of manipulation. The data in the following table is fictitious, and real-life decisions should not be based on it.

model,age,mileage,result
small,10,100000,0
small,10,200000,0
small,8,30000,1
small,3,10000,1
small,5,10000,1
medium,6,60000,0
medium,4,10000,1
medium,4,200000,0
medium,5,50000,1
family,2,60000,0
family,5,10000,1
family,4,200000,0
family,7,70000,1
family,1,20000,0
family,2,10000,1
sport,6,50000,1
sport,4,100000,0
sport,2,20000,1
sport,3,30000,1
sport,10,50000,1
sport,10,100000,1

To automate the decision we will use the OnlineLogisticRegression algorithm from Apache Mahout. The input of the algorithm will be a list of Observation objects. Each Observation contains a vector with the car details (type, age, mileage) and the actual category according to the input data (1 for manipulated, 0 for not manipulated). The first element of the vector is the intercept term, which always has the value 1 and is important for obtaining an accurate model; it plays the same role as the intercept in simple linear regression. The model is trained over 30 passes through the data, and every 10th pass we check its quality, measured as the area under the ROC curve (AUC), against the same input data set. If we had much more data available, we would use a separate subset of it for the quality check. The final step is to use the model to predict the fraud probability for car data not present in the training data set.
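The idea behind online training with several passes can be illustrated with plain stochastic gradient descent. The sketch below is a simplification, not Mahout's implementation: it uses a constant learning rate and no regularization, whereas OnlineLogisticRegression also applies the prior passed to its constructor (L1 in this tutorial) and lowers its learning rate as training progresses. As in the tutorial, index 0 of each feature vector is the intercept term with value 1. The class name and the toy data are hypothetical.

```java
// Minimal sketch of online (stochastic gradient descent) training for binary
// logistic regression: constant learning rate, no regularization.
// A simplification for illustration, not Mahout's implementation.
final class SgdSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Probability of category 1 for feature vector x under weights w
    static double predict(double[] w, double[] x) {
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    // Runs several passes over the data; each observation nudges the weights
    // along the gradient of its log-likelihood, (actual - predicted) * x
    static double[] train(double[][] xs, int[] ys, int passes, double learningRate) {
        double[] w = new double[xs[0].length];
        for (int pass = 0; pass < passes; pass++) {
            for (int n = 0; n < xs.length; n++) {
                double error = ys[n] - predict(w, xs[n]);
                for (int i = 0; i < w.length; i++) {
                    w[i] += learningRate * error * xs[n][i];
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Toy data: index 0 is the intercept (always 1), index 1 a single feature
        double[][] xs = { {1, 0.5}, {1, 1.0}, {1, 3.0}, {1, 3.5} };
        int[] ys = { 0, 0, 1, 1 };
        double[] w = train(xs, ys, 30, 0.5);
        System.out.println("P(1 | x=0.5) = " + predict(w, xs[0]));
        System.out.println("P(1 | x=3.5) = " + predict(w, xs[3]));
    }
}
```

After 30 passes the low feature value scores below 0.5 and the high one above 0.5. Without the intercept at index 0 the decision boundary would be forced through x = 0, which is why the intercept term matters for accuracy.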

Logistic regression Java project

Prerequisites: a Java Development Kit (JDK) and Apache Maven.

Create the Maven project:

mvn archetype:generate \
    -DarchetypeGroupId=org.apache.maven.archetypes \
    -DgroupId=com.technobium \
    -DartifactId=mahout-logistic-regression \
    -DinteractiveMode=false

Rename the App class created by default to LogisticRegression using the following command:

mv mahout-logistic-regression/src/main/java/com/technobium/App.java \
   mahout-logistic-regression/src/main/java/com/technobium/LogisticRegression.java

Add the Mahout and SLF4J libraries to this project:

cd mahout-logistic-regression
nano pom.xml

Add the following lines to the dependencies section:

<dependencies>
    ...
    <dependency>
        <groupId>org.apache.mahout</groupId>
        <artifactId>mahout-core</artifactId>
        <version>0.9</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>1.7.7</version>
    </dependency>
</dependencies>

In the same file, after the dependencies section, add the following configuration, which makes sure the code is compiled using Java 1.7:

<build>
    <plugins>
        <plugin>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
                <source>1.7</source>
                <target>1.7</target>
            </configuration>
        </plugin>
    </plugins>
</build>

Create an input folder and copy the file containing the training data, inputData.csv, into it:

mkdir input

Edit the LogisticRegression class file and replace its content with the following code:

package com.technobium;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class LogisticRegression {
	// Car types found in the training data, each with its own vector position
	static final List<String> CAR_TYPES = Arrays.asList("small", "medium",
			"family", "sport");
	// Vector layout: [intercept, age, mileage, one position per car type]
	static final int NUM_FEATURES = 3 + CAR_TYPES.size();

	public static void main(String[] args) {
		LogisticRegression logisticRegression = new LogisticRegression();

		// Load the input data
		List<Observation> trainingData = logisticRegression
				.parseInputFile("input/inputData.csv");

		// Train a model
		OnlineLogisticRegression olr = logisticRegression.train(trainingData);

		// Test the model
		logisticRegression.testModel(olr);
	}

	public List<Observation> parseInputFile(String inputFile) {
		List<Observation> result = new ArrayList<>();
		// try-with-resources (Java 7) closes the reader automatically
		try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) {
			// Skip the first line which contains the header values
			br.readLine();
			String line;
			// Prepare the observation data
			while ((line = br.readLine()) != null) {
				// Ignore blank lines, such as a trailing empty line
				if (line.trim().isEmpty()) {
					continue;
				}
				result.add(new Observation(line.split(",")));
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
		return result;
	}

	public OnlineLogisticRegression train(List<Observation> trainData) {
		// 2 categories (not fraud / fraud), NUM_FEATURES features, L1 prior
		OnlineLogisticRegression olr = new OnlineLogisticRegression(2,
				NUM_FEATURES, new L1());
		// Train the model using 30 passes
		for (int pass = 0; pass < 30; pass++) {
			for (Observation observation : trainData) {
				olr.train(observation.getActual(), observation.getVector());
			}
			// Every 10 passes check the quality of the trained model
			if (pass % 10 == 0) {
				// Auc computes the area under the ROC curve (printed as "Accuracy")
				Auc eval = new Auc(0.5);
				for (Observation observation : trainData) {
					eval.add(observation.getActual(),
							olr.classifyScalar(observation.getVector()));
				}
				System.out.format(
						"Pass: %2d, Learning rate: %2.4f, Accuracy: %2.4f%n",
						pass, olr.currentLearningRate(), eval.auc());
			}
		}
		return olr;
	}

	void testModel(OnlineLogisticRegression olr) {
		// A 10-year-old family car with a mileage of 100000, not in the training data
		Observation newObservation = new Observation(new String[] { "family",
				"10", "100000", "0" });
		Vector result = olr.classifyFull(newObservation.getVector());

		System.out.println("------------- Testing -------------");
		System.out.format("Probability of not fraud (0) = %.3f%n",
				result.get(0));
		System.out.format("Probability of fraud (1)     = %.3f%n",
				result.get(1));
	}

	static class Observation {
		private final DenseVector vector = new DenseVector(NUM_FEATURES);
		private final int actual;

		public Observation(String[] values) {
			// Index 0: intercept term, always 1
			vector.set(0, 1.0);
			// Index 1: car age in years
			vector.set(1, Double.parseDouble(values[1].trim()));
			// Index 2: feature scaling, divide mileage by 10000
			vector.set(2, Double.parseDouble(values[2].trim()) / 10000);
			// Indices 3 and up: one-hot encoding of the car type
			int typeIndex = CAR_TYPES.indexOf(values[0].trim());
			if (typeIndex < 0) {
				throw new IllegalArgumentException("Unknown car type: "
						+ values[0]);
			}
			vector.set(3 + typeIndex, 1.0);

			this.actual = Integer.parseInt(values[3].trim());
		}

		public Vector getVector() {
			return vector;
		}

		public int getActual() {
			return actual;
		}
	}
}

Compile and run the class using the following commands:

mvn compile
mvn exec:java -Dexec.mainClass="com.technobium.LogisticRegression"

The output should look similar to this:

Pass:  0, Learning rate: 0.1759, Accuracy: 0.9615
Pass: 10, Learning rate: 0.0511, Accuracy: 0.9712
Pass: 20, Learning rate: 0.0303, Accuracy: 0.9712
------------- Testing -------------
Probability of not fraud (0) = 0.090
Probability of fraud (1)     = 0.910
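The value printed as "Accuracy" comes from Mahout's Auc class, so it is the area under the ROC curve (AUC) rather than the fraction of correctly classified rows. AUC can be read as the probability that a randomly chosen manipulated car receives a higher score than a randomly chosen honest one. The sketch below computes it exactly that way, by comparing every positive example with every negative one; the class name and the data are hypothetical, and Mahout's Auc may differ in implementation details.

```java
// Minimal sketch: AUC as the fraction of (positive, negative) pairs in which
// the positive example gets the higher score. Ties count as half a pair.
// Returns NaN if there are no positives or no negatives.
final class AucSketch {

    static double auc(int[] actual, double[] scores) {
        double correctPairs = 0.0;
        long totalPairs = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] != 1) continue;      // i must be a positive example
            for (int j = 0; j < actual.length; j++) {
                if (actual[j] != 0) continue;  // j must be a negative example
                totalPairs++;
                if (scores[i] > scores[j]) {
                    correctPairs += 1.0;
                } else if (scores[i] == scores[j]) {
                    correctPairs += 0.5;
                }
            }
        }
        return correctPairs / totalPairs;
    }

    public static void main(String[] args) {
        int[] actual = { 0, 0, 1, 1 };
        double[] scores = { 0.1, 0.6, 0.4, 0.9 };
        System.out.println("AUC = " + auc(actual, scores));
    }
}
```

For the example scores, three of the four positive/negative pairs are ranked correctly, so the program prints AUC = 0.75. An AUC of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which is why values around 0.97 in the output above indicate a good fit to the training data.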

For testing we used new data, not present in the training set: a family car that is 10 years old and has been driven 100000 kilometers. For this input, the model estimates a 91% probability that the car's mileage was manipulated. This estimate is based solely on the data used in the training phase.

Conclusion

Typical uses of logistic regression include fraud detection, manufacturing defect detection, weather prediction, email filtering (spam or ham) and case classification in medicine. Closely related to linear regression, it is one of the most widely used machine learning classification algorithms.

GitHub repository for this project: https://github.com/technobium/mahout-logistic-regression/

