Logistic regression using Apache Mahout
Logistic regression is a supervised learning algorithm used to classify input data into categories. If there are two possible categories, we use binary (binomial) logistic regression; if there are more than two categories, we use multinomial logistic regression. For binary logistic regression, the algorithm finds the mathematical function that best fits the training data. This function is the sigmoid (logistic) function applied to a weighted sum of the input values, so it takes values between 0 and 1. The classifier uses the trained function to return, for new input data, the probability that the data belongs to one category or the other.
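To make the sigmoid concrete, here is a minimal plain-Java sketch (no Mahout) of how a trained binary model turns an input into a probability: it computes a weighted sum z of the features, with an intercept term fixed at 1, and applies the sigmoid 1 / (1 + e^(-z)). The class name `SigmoidSketch` and the weights in `main` are hypothetical, chosen by hand for illustration rather than learned from data.

```java
/** Minimal sketch: how a binary logistic regression model turns features into a probability. */
final class SigmoidSketch {

    /** The logistic (sigmoid) function: maps any real number into the open interval (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Probability of category 1 for feature vector x, given weights w.
     * By convention in this sketch, x[0] is the intercept term and is always 1.
     */
    static double probability(double[] w, double[] x) {
        if (w.length != x.length) {
            throw new IllegalArgumentException("weights and features must have the same length");
        }
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i]; // weighted sum of the inputs
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        // Hypothetical weights, chosen by hand for illustration (not learned from data).
        double[] w = {0.5, -2.0, 1.0};
        double[] x = {1.0, 0.25, 0.75}; // intercept, feature 1, feature 2
        double p = probability(w, x);
        System.out.printf("P(category 1) = %.3f, P(category 0) = %.3f%n", p, 1 - p);
        System.out.println(p >= 0.5 ? "classified as 1" : "classified as 0");
    }
}
```

Reading a probability of 0.5 or more as category 1 is a common convention for turning the probability into a decision; the model itself only outputs the probability.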
In this tutorial we will use the logistic regression algorithm to solve a fraud detection problem. In the following sections we will define the problem, explain how logistic regression helps solve it, and build a small working example in Java. Basic Java programming knowledge is required for this tutorial.
Detect car mileage fraud using logistic regression
Mileage fraud, tachograph fraud, odometer fraud, “busting miles” or “clocking” all mean fraudulently manipulating the mileage indicator of a vehicle in order to increase its market value. There are many tips and tricks you can use to tell whether a car has been subject to this type of fraud. For this example we will use a common-sense check: if the car has too few miles/kilometers for its age, we can suspect that the mileage indicator was manipulated. Logistic regression is the machine learning algorithm we will use to automate the decision whether a car is suspected of mileage fraud or not. We will use a small data set to train the logistic regression algorithm. The output of the training step is a mathematical function, which we will use to predict the fraud probability for new input data.
In production environments, the training is usually executed against a much larger data set (millions of entries). Also, the input data is split in two: one part on which the model is trained, and one part on which the quality of the trained model is verified. For instance, 60% of the initial data can be used for training the model and 40% for testing it.
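A shuffled 60% / 40% split can be sketched in plain Java as follows. The class `TrainTestSplit`, the fixed random seed and the rounding of the cut point are illustrative assumptions, not part of the tutorial's project. Shuffling first reduces the chance that, for example, all rows of one car type end up in the test set just because the input file happens to be sorted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: shuffled train/test split of a list of rows. */
final class TrainTestSplit {

    /** Holds the two parts of a split. */
    static final class Split<T> {
        final List<T> train;
        final List<T> test;

        Split(List<T> train, List<T> test) {
            this.train = train;
            this.test = test;
        }
    }

    /**
     * Shuffles a copy of the data with the given seed, then puts the first trainFraction
     * of the rows into the training set and the remaining rows into the test set.
     */
    static <T> Split<T> split(List<T> data, double trainFraction, long seed) {
        if (trainFraction < 0.0 || trainFraction > 1.0) {
            throw new IllegalArgumentException("trainFraction must be in [0, 1]");
        }
        List<T> shuffled = new ArrayList<T>(data); // leave the caller's list untouched
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        return new Split<T>(new ArrayList<T>(shuffled.subList(0, cut)),
                new ArrayList<T>(shuffled.subList(cut, shuffled.size())));
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<Integer>();
        for (int i = 0; i < 20; i++) {
            rows.add(i); // stand-ins for 20 data rows
        }
        Split<Integer> s = split(rows, 0.6, 42L);
        System.out.println("train: " + s.train.size() + " rows, test: " + s.test.size() + " rows");
    }
}
```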
In our example we will have an input file with training data. Each line contains the car type, the car age (in years) and the car mileage, and the last column holds a binary value which tells whether the mileage was manipulated (value 1) or not (value 0). I used generic car types, but you can use real-life car models. As you can see in the first line, a 10-year-old small car with 100000 km is not suspected of manipulation. In the last line, however, a sports car with the same age and the same mileage is suspected of manipulation. The data in the following table is fictional, and real-life decisions should not be based on it.
model | age (years) | mileage (km) | result |
---|---|---|---|
small | 10 | 100000 | 0 |
small | 10 | 200000 | 0 |
small | 8 | 30000 | 1 |
small | 3 | 10000 | 1 |
small | 5 | 10000 | 1 |
medium | 6 | 60000 | 0 |
medium | 4 | 10000 | 1 |
medium | 4 | 200000 | 0 |
medium | 5 | 50000 | 1 |
family | 2 | 60000 | 0 |
family | 5 | 10000 | 1 |
family | 4 | 200000 | 0 |
family | 7 | 70000 | 1 |
family | 1 | 20000 | 0 |
family | 2 | 10000 | 1 |
sport | 6 | 50000 | 1 |
sport | 4 | 100000 | 0 |
sport | 2 | 20000 | 1 |
sport | 3 | 30000 | 1 |
sport | 10 | 5000 | 1 |
sport | 10 | 100000 | 1 |
To automate the decision we will use the OnlineLogisticRegression class from Apache Mahout, which trains a logistic regression model with stochastic gradient descent (SGD). The input of the algorithm is a list of Observation objects. Each Observation contains a vector with the car details (type, age, mileage) and the actual category from the input data (1 for manipulated, 0 for not manipulated). The vector also contains an intercept term with the value 1, which is important in order to obtain an accurate model; the intercept term plays the same role in simple linear regression. The model is trained in 30 passes over the data, and every 10th pass (passes 0, 10 and 20) we check its quality against the same input data set. If we had much more data available, we would have used a separate subset of the data for the quality check. The final step is to use the model to predict the fraud probability for car data not present in the training data set.
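What "training in passes" means can be sketched without Mahout. The following plain-Java class fits a binary logistic regression to the table above with stochastic gradient descent: after every row it moves each weight by learningRate × (actual − predicted) × feature, and it repeats this for 30 passes. It is a simplified stand-in, not Mahout's OnlineLogisticRegression: it assumes a fixed, hypothetical learning rate of 0.001 instead of a decaying one, uses no L1 prior, puts the intercept explicitly at index 0, and one-hot encodes the car type (with "small" as the baseline) where the Mahout example hashes the car type into the vector. The class name `SgdLogisticRegressionSketch` is hypothetical.

```java
/**
 * Plain-Java sketch of binary logistic regression trained with stochastic gradient
 * descent (SGD). Simplified stand-in, not Mahout's OnlineLogisticRegression:
 * fixed learning rate, no regularization prior.
 */
final class SgdLogisticRegressionSketch {

    // Training rows from the table: model, age (years), mileage (km), result (1 = manipulated).
    static final String[][] ROWS = {
        {"small", "10", "100000", "0"}, {"small", "10", "200000", "0"},
        {"small", "8", "30000", "1"}, {"small", "3", "10000", "1"},
        {"small", "5", "10000", "1"}, {"medium", "6", "60000", "0"},
        {"medium", "4", "10000", "1"}, {"medium", "4", "200000", "0"},
        {"medium", "5", "50000", "1"}, {"family", "2", "60000", "0"},
        {"family", "5", "10000", "1"}, {"family", "4", "200000", "0"},
        {"family", "7", "70000", "1"}, {"family", "1", "20000", "0"},
        {"family", "2", "10000", "1"}, {"sport", "6", "50000", "1"},
        {"sport", "4", "100000", "0"}, {"sport", "2", "20000", "1"},
        {"sport", "3", "30000", "1"}, {"sport", "10", "5000", "1"},
        {"sport", "10", "100000", "1"}
    };

    // One-hot car type columns; "small" is the baseline and has no column of its own.
    static final String[] TYPES = {"medium", "family", "sport"};

    /** Feature vector: [intercept = 1, age, mileage / 10000, one-hot car type...]. */
    static double[] features(String model, double age, double mileage) {
        double[] x = new double[3 + TYPES.length];
        x[0] = 1.0;             // intercept term, always 1
        x[1] = age;
        x[2] = mileage / 10000; // same feature scaling as in the Mahout example
        for (int i = 0; i < TYPES.length; i++) {
            if (TYPES[i].equals(model)) {
                x[3 + i] = 1.0;
            }
        }
        return x;
    }

    /** Feature vector for one table row. */
    static double[] rowFeatures(String[] r) {
        return features(r[0], Double.parseDouble(r[1]), Double.parseDouble(r[2]));
    }

    /** Predicted probability of category 1 (manipulated). */
    static double predict(double[] w, double[] x) {
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-z)); // sigmoid
    }

    /** Average negative log-likelihood (log loss) of the weights on the training rows. */
    static double logLoss(double[] w) {
        double sum = 0.0;
        for (String[] r : ROWS) {
            double p = predict(w, rowFeatures(r));
            sum -= Integer.parseInt(r[3]) == 1 ? Math.log(p) : Math.log(1 - p);
        }
        return sum / ROWS.length;
    }

    /** Runs the given number of passes over the data, updating the weights after every row. */
    static double[] train(int passes, double learningRate) {
        double[] w = new double[3 + TYPES.length]; // all weights start at 0, i.e. p = 0.5
        for (int pass = 0; pass < passes; pass++) {
            for (String[] r : ROWS) {
                double[] x = rowFeatures(r);
                double error = Integer.parseInt(r[3]) - predict(w, x); // actual - predicted
                for (int i = 0; i < w.length; i++) {
                    w[i] += learningRate * error * x[i]; // gradient step on the log loss
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] w = train(30, 0.001);
        System.out.printf("log loss before: %.4f, after 30 passes: %.4f%n",
                logLoss(new double[w.length]), logLoss(w));
        System.out.printf("P(fraud) for a 10-year-old family car with 100000 km: %.3f%n",
                predict(w, features("family", 10, 100000)));
    }
}
```

Comparing the log loss before and after training is a simple way to see that the passes improve the fit. Because of the simplifications listed above, its probabilities should not be expected to match the Mahout results later in this tutorial.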
Logistic regression Java project
Prerequisites:
- Linux or Mac
- Java 1.7
- Apache Maven 3
Create the Maven project:
```bash
mvn archetype:generate \
  -DarchetypeGroupId=org.apache.maven.archetypes \
  -DgroupId=com.technobium \
  -DartifactId=mahout-logistic-regression \
  -DinteractiveMode=false
```
Rename the App class created by default to LogisticRegression using the following command:
```bash
mv mahout-logistic-regression/src/main/java/com/technobium/App.java \
  mahout-logistic-regression/src/main/java/com/technobium/LogisticRegression.java
```
Add the Mahout and SLF4J libraries to this project:
```bash
cd mahout-logistic-regression
nano pom.xml
```
Add the following lines to the dependencies section:
```xml
<dependencies>
    ...
    <dependency>
        <groupId>org.apache.mahout</groupId>
        <artifactId>mahout-core</artifactId>
        <version>0.9</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>1.7.7</version>
    </dependency>
</dependencies>
```
In the same file, after the dependencies section, add the following configuration, which makes sure the code is compiled using Java 1.7:
```xml
<build>
    <plugins>
        <plugin>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
                <source>1.7</source>
                <target>1.7</target>
            </configuration>
        </plugin>
    </plugins>
</build>
```
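Create an input folder and copy into it the file containing the training data, inputData.csv. The file holds the data from the table above: a header line, which the program skips, followed by one comma-separated line per car (model, age, mileage, result).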
```bash
mkdir input
```
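If you do not have a copy of inputData.csv, a small program can create it from the table above. This sketch assumes the format that `parseInputFile()` expects: one header line (which is skipped), then one comma-separated line per car in the order model, age, mileage, result. The class name `WriteTrainingData` and the header names are hypothetical; any header works because it is never parsed. Run it from the `mahout-logistic-regression` directory so that the file lands in `input/inputData.csv`.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper: writes the tutorial's training table to a CSV file (Java 7 NIO). */
final class WriteTrainingData {

    static final List<String> LINES = Arrays.asList(
            "model,age,mileage,result", // header line, skipped by the parser
            "small,10,100000,0",
            "small,10,200000,0",
            "small,8,30000,1",
            "small,3,10000,1",
            "small,5,10000,1",
            "medium,6,60000,0",
            "medium,4,10000,1",
            "medium,4,200000,0",
            "medium,5,50000,1",
            "family,2,60000,0",
            "family,5,10000,1",
            "family,4,200000,0",
            "family,7,70000,1",
            "family,1,20000,0",
            "family,2,10000,1",
            "sport,6,50000,1",
            "sport,4,100000,0",
            "sport,2,20000,1",
            "sport,3,30000,1",
            "sport,10,5000,1",
            "sport,10,100000,1");

    /** Creates missing parent folders, then writes one line per entry of LINES. */
    static Path write(Path target) throws IOException {
        Path parent = target.getParent();
        if (parent != null) {
            Files.createDirectories(parent); // no-op if the folder already exists
        }
        return Files.write(target, LINES, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        Path written = write(Paths.get("input", "inputData.csv"));
        System.out.println("Wrote " + (LINES.size() - 1) + " training rows to " + written);
    }
}
```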
Edit the LogisticRegression class file and replace its content with the following code:
```java
package com.technobium;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

public class LogisticRegression {

    public static void main(String[] args) {
        LogisticRegression logisticRegression = new LogisticRegression();

        // Load the input data
        List<Observation> trainingData = logisticRegression
                .parseInputFile("input/inputData.csv");

        // Train a model
        OnlineLogisticRegression olr = logisticRegression.train(trainingData);

        // Test the model
        logisticRegression.testModel(olr);
    }

    public List<Observation> parseInputFile(String inputFile) {
        List<Observation> result = new ArrayList<Observation>();

        // try-with-resources (Java 7) closes the reader even if an exception is thrown
        try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) {
            // Skip the first line, which contains the header values
            String line = br.readLine();

            // Each remaining line is: model,age,mileage,result
            while ((line = br.readLine()) != null) {
                String[] values = line.split(",");
                result.add(new Observation(values));
            }
        } catch (IOException e) {
            // Also covers FileNotFoundException when the input file is missing
            e.printStackTrace();
        }
        return result;
    }

    public OnlineLogisticRegression train(List<Observation> trainData) {
        // 2 categories (0 and 1), feature vectors of size 4, L1 prior
        OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 4, new L1());

        // Train the model using 30 passes over the training data
        for (int pass = 0; pass < 30; pass++) {
            for (Observation observation : trainData) {
                olr.train(observation.getActual(), observation.getVector());
            }

            // Every 10 passes (0, 10 and 20) evaluate the model on the training data.
            // Auc computes the area under the ROC curve, printed here as "Accuracy".
            if (pass % 10 == 0) {
                Auc eval = new Auc(0.5);
                for (Observation observation : trainData) {
                    eval.add(observation.getActual(),
                            olr.classifyScalar(observation.getVector()));
                }
                System.out.format(
                        "Pass: %2d, Learning rate: %2.4f, Accuracy: %2.4f\n",
                        pass, olr.currentLearningRate(), eval.auc());
            }
        }
        return olr;
    }

    void testModel(OnlineLogisticRegression olr) {
        // A car that is not in the training data: a 10-year-old family car with 100000 km.
        // The last value (the actual category) is not used for classification.
        Observation newObservation = new Observation(new String[] {
                "family", "10", "100000", "0" });

        // classifyFull returns one probability per category
        Vector result = olr.classifyFull(newObservation.getVector());

        System.out.println("------------- Testing -------------");
        System.out.format("Probability of not fraud (0) = %.3f\n", result.get(0));
        System.out.format("Probability of fraud (1) = %.3f\n", result.get(1));
    }

    class Observation {
        private DenseVector vector = new DenseVector(4);
        private int actual;

        public Observation(String[] values) {
            ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept");
            StaticWordValueEncoder encoder = new StaticWordValueEncoder("feature");

            // Intercept term with the value 1. The encoder hashes it into one of the
            // 4 vector positions; if that position is 0 or 1, the set() calls below
            // overwrite it.
            interceptEncoder.addToVector("1", vector);
            // Car age in years
            vector.set(0, Double.parseDouble(values[1]));
            // Feature scaling, divide mileage by 10000
            vector.set(1, Double.parseDouble(values[2]) / 10000);
            // Car type (small, medium, family, sport), added at a hashed position,
            // possibly on top of one of the values set above
            encoder.addToVector(values[0], vector);

            // Actual category: 1 = manipulated, 0 = not manipulated
            this.actual = Integer.parseInt(values[3]);
        }

        public Vector getVector() {
            return vector;
        }

        public int getActual() {
            return actual;
        }
    }
}
```
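The quality check in `train()` uses Mahout's `Auc` class, so the value printed as "Accuracy" is the area under the ROC curve (AUC): the probability that a randomly chosen manipulated car gets a higher fraud score than a randomly chosen non-manipulated one. The plain-Java sketch below computes that definition directly by counting all (manipulated, non-manipulated) pairs. `AucSketch` is a hypothetical name, and Mahout's own class may compute or estimate the value differently.

```java
/** Plain-Java sketch of the area under the ROC curve (AUC) via pairwise comparisons. */
final class AucSketch {

    /**
     * Fraction of (positive, negative) pairs in which the positive example (actual == 1)
     * gets the higher score than the negative one (actual == 0); ties count as half.
     */
    static double auc(int[] actual, double[] score) {
        if (actual.length != score.length) {
            throw new IllegalArgumentException("actual and score must have the same length");
        }
        double correct = 0.0;
        long pairs = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] != 1) {
                continue; // i iterates over positives only
            }
            for (int j = 0; j < actual.length; j++) {
                if (actual[j] != 0) {
                    continue; // j iterates over negatives only
                }
                pairs++;
                if (score[i] > score[j]) {
                    correct += 1.0;
                } else if (score[i] == score[j]) {
                    correct += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("need at least one example of each category");
        }
        return correct / pairs;
    }

    public static void main(String[] args) {
        int[] actual = {1, 0, 1, 0};
        double[] score = {0.9, 0.3, 0.4, 0.6}; // e.g. outputs of classifyScalar()
        System.out.println("AUC = " + auc(actual, score));
    }
}
```

An AUC of 1.0 means every manipulated car was scored above every non-manipulated one, while 0.5 is no better than a random ranking. Unlike accuracy, AUC does not depend on a particular decision threshold.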
Compile and run the class using the following commands:
```bash
mvn compile
mvn exec:java -Dexec.mainClass="com.technobium.LogisticRegression"
```
This should output something like this:
```text
Pass:  0, Learning rate: 0.1759, Accuracy: 0.9615
Pass: 10, Learning rate: 0.0511, Accuracy: 0.9712
Pass: 20, Learning rate: 0.0303, Accuracy: 0.9712
------------- Testing -------------
Probability of not fraud (0) = 0.090
Probability of fraud (1) = 0.910
```
For testing we used new data that is not present in the training data: a family car which is 10 years old and has 100000 kilometers on the odometer. For this input, the algorithm tells us that there is a 91% chance that the mileage of the car was manipulated. This decision is based only on the data given as input during the training phase.
Conclusion
Typical uses of logistic regression are fraud detection, manufacturing error detection, weather prediction, mail filtering (spam or ham) and case classification in medicine. Closely related to linear regression, this classification algorithm is one of the most widely used machine learning algorithms.
GitHub repository for this project: https://github.com/technobium/mahout-logistic-regression/