Poisonous Mushroom Prediction (Decision Tree) in C#
Revision as of 21:02, 13 May 2009 by Emgucv (talk | contribs) (New page: '''This example requires Emgu CV 2.0.0.0 which is available from SVN or Nightly build only''' Emgu CV version >= 2.0 contains working DTree class (...)
This example requires Emgu CV 2.0.0.0, which is available only from SVN or as a nightly build.
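Emgu CV version 2.0 and later contains a working DTree class (it fixes a bug in the version 1.5.0.1 implementation). The following code is a unit test for the DTree class, which wraps OpenCV's decision tree; it is also a .NET port of the OpenCV mushroom.exe example. The test reads its data from the file agaricus-lepiota.data in the current working directory.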
using System.Drawing;
using Emgu.CV;
using Emgu.CV.Structure;
using Emgu.CV.ML;
using Emgu.CV.ML.Structure;
...
/// <summary>
/// Loads the mushroom records from "agaricus-lepiota.data" in the current working directory.
/// Each non-blank line is a comma-separated record: the first field is the class label and
/// the remaining fields are single-character categorical attribute values.
/// </summary>
/// <param name="data">
/// Receives one row per record and one column per attribute, holding the character code
/// of each attribute value.
/// </param>
/// <param name="response">
/// Receives a single column holding the character code of each record's class label.
/// </param>
/// <exception cref="System.IO.InvalidDataException">
/// The file contains no records, or a record's field count differs from the first record's.
/// </exception>
private void ReadMushroomData(out Matrix<float> data, out Matrix<float> response)
{
   // Drop blank lines (e.g. a trailing empty line) so they are neither counted as samples
   // when sizing the matrices nor passed to Convert.ToChar, which rejects empty strings.
   string[] rows = Array.FindAll(
      System.IO.File.ReadAllLines("agaricus-lepiota.data"),
      line => line.Trim().Length > 0);
   if (rows.Length == 0)
   {
      throw new System.IO.InvalidDataException("agaricus-lepiota.data contains no records");
   }

   // The first record defines the layout: one label field followed by varCount attributes.
   int varCount = rows[0].Split(',').Length - 1;
   data = new Matrix<float>(rows.Length, varCount);
   response = new Matrix<float>(rows.Length, 1);

   for (int row = 0; row < rows.Length; row++)
   {
      string[] values = rows[row].Split(',');
      if (values.Length != varCount + 1)
      {
         // Without this check a short record would be silently zero-padded and a long
         // one would fail with an unhelpful index error.
         throw new System.IO.InvalidDataException(String.Format(
            "Record {0} has {1} fields; expected {2}", row + 1, values.Length, varCount + 1));
      }

      // Class label: stored as the character code of its single-character value.
      Char c = System.Convert.ToChar(values[0]);
      response[row, 0] = System.Convert.ToInt32(c);

      // Attributes: each single-character value is stored as its character code.
      for (int i = 1; i < values.Length; i++)
      {
         data[row, i - 1] = System.Convert.ToByte(System.Convert.ToChar(values[i]));
      }
   }
}
/// <summary>
/// Trains a decision tree on the first 80% of the mushroom records. It then writes to
/// Trace the prediction accuracy on those training rows and on the remaining held-out rows.
/// Throws if training reports failure, instead of passing silently.
/// </summary>
[Test]
public void TestDTreesMushroom()
{
   Matrix<float> data, response;
   ReadMushroomData(out data, out response);

   //Use the first 80% of data as training sample
   int trainingSampleCount = (int)(data.Rows * 0.8);

   // One type entry per input variable plus one for the response (hence Cols + 1).
   Matrix<Byte> varType = new Matrix<byte>(data.Cols + 1, 1);
   varType.SetValue((byte)MlEnum.VAR_TYPE.CATEGORICAL); //the data is categorical

   // Row mask: the first trainingSampleCount rows are set to 255 and used for training.
   // The remaining rows are not set here (assumed zero-initialised) and serve as test data.
   Matrix<byte> sampleIdx = new Matrix<byte>(data.Rows, 1);
   using (Matrix<byte> sampleRows = sampleIdx.GetRows(0, trainingSampleCount, 1))
   {
      sampleRows.SetValue(255);
   }

   // The priors are handed to native code as a raw pointer, so the array is pinned.
   // It must be unpinned on every exit path, which the finally block below guarantees.
   float[] priors = new float[] {1, 0.5f};
   GCHandle priorsHandle = GCHandle.Alloc(priors, GCHandleType.Pinned);
   try
   {
      MCvDTreeParams param = new MCvDTreeParams();
      param.maxDepth = 8;
      param.minSampleCount = 10;
      param.regressionAccuracy = 0;
      param.useSurrogates = true;
      param.maxCategories = 15;
      param.cvFolds = 10;
      param.use1seRule = true;
      param.truncatePrunedTree = true;
      param.priors = priorsHandle.AddrOfPinnedObject();

      using (DTree dtree = new DTree())
      {
         bool success = dtree.Train(
            data,
            Emgu.CV.ML.MlEnum.DATA_LAYOUT_TYPE.ROW_SAMPLE,
            response,
            null,
            sampleIdx,
            varType,
            null,
            param);
         // A bare return here would make the test pass even though nothing was trained.
         if (!success)
         {
            throw new InvalidOperationException("DTree training failed");
         }

         double trainDataCorrectRatio = 0;
         double testDataCorrectRatio = 0;
         for (int i = 0; i < data.Rows; i++)
         {
            using (Matrix<float> sample = data.GetRow(i))
            {
               double r = dtree.Predict(sample, null, false).value;
               // Predicted and actual labels are float-encoded category codes; compare with a tolerance.
               r = Math.Abs(r - response[i, 0]);
               if (r < 1.0e-5)
               {
                  if (i < trainingSampleCount)
                  {
                     trainDataCorrectRatio++;
                  }
                  else
                  {
                     testDataCorrectRatio++;
                  }
               }
            }
         }
         trainDataCorrectRatio /= trainingSampleCount;
         testDataCorrectRatio /= (data.Rows - trainingSampleCount);
         Trace.WriteLine(String.Format("Prediction accuracy for training data :{0}%", trainDataCorrectRatio*100));
         Trace.WriteLine(String.Format("Prediction accuracy for test data :{0}%", testDataCorrectRatio*100));
      }
   }
   finally
   {
      // Unpin the priors even when training fails or Train/Predict throws.
      priorsHandle.Free();
   }
}
The result of running this unit test:
Prediction accuracy for training data: 99.8769041390983%
Prediction accuracy for test data: 99.2615384615385%
That is a very good prediction rate. Many thanks to the OpenCV developers.