// You are on page 1 of 5

// STIS Assignment 2

// PRASHANT GUPTA 2010IPG-071


import java.io.*;
import java.util.Arrays;
import java.util.Scanner;

/**
 * Naive Bayes classifier for 28x28 ASCII-art digit images.
 *
 * <p>Images are text: every image is 28 lines, and a {@code '#'} or {@code '+'} character is an
 * "on" pixel while anything else is "off". Labels are whitespace-separated integers 0-9, one per
 * image, in the same order as the images. Usage: construct (which trains), call
 * {@link #priorprob()}, then {@link #testdigit()}.
 */
public class Classify {

    /** Number of digit classes (0-9). */
    private static final int CLASSES = 10;
    /** Width and height, in characters, of every ASCII-art digit image. */
    private static final int SIZE = 28;

    int correctclass = 0; // The correct class of the current test digit
    double[] truematches = new double[CLASSES]; // Number of correct classifications per class
    double totaltruematch = 0; // Number of correct classifications over all classes
    double[] testdigitcount = new double[CLASSES]; // Number of test digits seen per class
    double count = 0; // Number of test digits classified so far
    double smooth = 1; // Laplace smoothing constant added to every pixel count
    // Total number of training images; recomputed from digitCount by priorprob()
    double total = 5000;
    // Prior probabilities P(class)
    double[] priorprobs = new double[CLASSES];
    // During training: smoothed count of "on" pixels per class and position.
    // After priorprob(): P(pixel on | class).
    double[][][] digitlikelihood = new double[CLASSES][SIZE][SIZE];
    // Number of training images of each digit class (0-9)
    double[] digitCount = new double[CLASSES];
    // Pixels (1 = '#' or '+', 0 = anything else) of the test digit being classified
    double[][] testdigit = new double[SIZE][SIZE];

    /**
     * Builds a classifier trained on {@code datafiles/trainingimages} and
     * {@code datafiles/traininglabels}. Call {@link #priorprob()} before classifying.
     *
     * @throws IOException if a training file cannot be read or contains a label outside 0-9
     */
    public Classify() throws IOException {
        smoothvalues();
        trainLikelihood();
    }

    /**
     * Builds a classifier trained on the given image and label streams instead of the default
     * files. The caller keeps ownership of both readers and is responsible for closing them.
     *
     * @throws IOException if reading fails or a label lies outside 0-9
     */
    Classify(Reader trainingImages, Reader trainingLabels) throws IOException {
        smoothvalues();
        trainLikelihood(trainingImages, trainingLabels);
    }

    /** Trains on the default files, then classifies the default test set and prints a report. */
    public static void main(String[] args) throws IOException {
        Classify finddigit = new Classify();
        finddigit.priorprob();
        finddigit.testdigit();
    }

    /**
     * Seeds every per-class pixel count with {@link #smooth} (Laplace smoothing) so that no
     * estimated likelihood can be exactly 0 or 1.
     */
    public void smoothvalues() {
        for (double[][] digit : digitlikelihood) {
            for (double[] row : digit) {
                Arrays.fill(row, smooth);
            }
        }
    }

    /**
     * Accumulates per-class pixel counts from {@code datafiles/trainingimages} and
     * {@code datafiles/traininglabels}. Both files are closed before returning.
     *
     * @throws IOException if a file cannot be read or a label lies outside 0-9
     */
    public void trainLikelihood() throws IOException {
        try (Reader images = new FileReader("datafiles/trainingimages");
             Reader labels = new FileReader("datafiles/traininglabels")) {
            trainLikelihood(images, labels);
        }
    }

    /**
     * Accumulates per-class pixel counts from the given streams: one image per label, stopping
     * when either stream runs out. Readers are not closed here.
     */
    void trainLikelihood(Reader images, Reader labelSource) throws IOException {
        // Not closed on purpose: closing the Scanner would close the caller-owned reader.
        Scanner labels = new Scanner(labelSource);
        double[][] pixels = new double[SIZE][SIZE];
        while (labels.hasNextInt()) {
            int label = checkLabel(labels.nextInt());
            if (!readImage(images, pixels)) {
                break; // more labels than images
            }
            digitCount[label]++;
            for (int r = 0; r < SIZE; r++) {
                for (int c = 0; c < SIZE; c++) {
                    digitlikelihood[label][r][c] += pixels[r][c];
                }
            }
        }
    }

    /**
     * Converts the training counts into probabilities: priors P(class) from the actual number
     * of training images, and smoothed likelihoods P(pixel on | class). Also prints the number
     * of training images per digit. Must be called exactly once, after training.
     */
    public void priorprob() {
        total = 0;
        for (double n : digitCount) {
            total += n;
        }
        System.out.println("Total no. of occurences of each digit in training dataset\n");
        for (int i = 0; i < CLASSES; i++) {
            priorprobs[i] = total > 0 ? digitCount[i] / total : 0;
            // Binary features: (count + smooth) / (N + 2 * smooth). The numerator already holds
            // the smoothing term from smoothvalues(); 2*smooth keeps p strictly below 1 so that
            // log(1 - p) in classify() never becomes -Infinity.
            double denominator = digitCount[i] + 2 * smooth;
            for (int j = 0; j < SIZE; j++) {
                for (int k = 0; k < SIZE; k++) {
                    digitlikelihood[i][j][k] /= denominator;
                }
            }
            System.out.println("Digit " + i + " : " + digitCount[i] + "\n");
        }
    }

    /**
     * Classifies every image in {@code datafiles/testimages} against
     * {@code datafiles/testlabels} and prints the summary report.
     *
     * @throws IOException if a file cannot be read or a label lies outside 0-9
     */
    public void testdigit() throws IOException {
        try (Reader images = new FileReader("datafiles/testimages");
             Reader labels = new FileReader("datafiles/testlabels")) {
            testdigit(images, labels);
        }
    }

    /**
     * Classifies one image per label from the given streams, stopping when either runs out,
     * then prints the summary report. Readers are not closed here.
     */
    void testdigit(Reader images, Reader labelSource) throws IOException {
        // Not closed on purpose: closing the Scanner would close the caller-owned reader.
        Scanner labels = new Scanner(labelSource);
        while (labels.hasNextInt()) {
            int label = checkLabel(labels.nextInt());
            if (!readImage(images, testdigit)) {
                break; // more labels than images
            }
            correctclass = label;
            testdigitcount[label]++;
            checkClass(1);
        }
        answer();
    }

    /**
     * Classifies the digit currently held in {@link #testdigit}, compares the prediction with
     * {@link #correctclass} and updates the counters.
     *
     * @param input {@code -1} to print the summary report after this digit; any other value
     *     only updates the counters
     */
    public void checkClass(int input) {
        int best = classify(testdigit);
        if (best == correctclass) {
            truematches[best]++;
            totaltruematch++;
        }
        count++;

        if (input == -1) {
            answer();
        }
    }

    /**
     * Returns the class with the highest log posterior for {@code pixels}; ties go to the lower
     * class index. Requires {@link #priorprob()} to have been called.
     */
    int classify(double[][] pixels) {
        double[] logPosterior = new double[CLASSES];
        int best = 0;
        for (int i = 0; i < CLASSES; i++) {
            double score = Math.log(priorprobs[i]);
            for (int j = 0; j < SIZE; j++) {
                for (int k = 0; k < SIZE; k++) {
                    double p = digitlikelihood[i][j][k];
                    score += Math.log(pixels[j][k] == 0 ? 1 - p : p);
                }
            }
            logPosterior[i] = score;
            if (logPosterior[i] > logPosterior[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Prints the per-digit test counts, per-digit classification rates and overall rate. */
    public void answer() {
        System.out.println("Total no. of occurences of each digit in test dataset\n");
        for (int i = 0; i < CLASSES; i++) {
            System.out.println("Digit" + i + ":" + testdigitcount[i] + "\n");
        }
        System.out.println("\n Classification rate (percentage) for each digit : \n");
        printClassificationRate();
        System.out.println("\n Overall classification rate \n");
        System.out.println("Count: " + totaltruematch + " / " + count);
        System.out.println("Percentage: " + (float) (totaltruematch / count) * 100);
    }

    /** Prints the percentage of correctly classified test digits for each class. */
    public void printClassificationRate() {
        for (int i = 0; i < CLASSES; i++) {
            System.out.println("Digit " + i + " : "
                    + (float) (truematches[i] / testdigitcount[i]) * 100 + "\n");
        }
    }

    /** Rejects labels that would index outside the per-class arrays. */
    private static int checkLabel(int label) throws IOException {
        if (label < 0 || label >= CLASSES) {
            throw new IOException("digit label out of range 0-9: " + label);
        }
        return label;
    }

    /**
     * Reads one image of {@code SIZE} text lines from {@code in} into {@code target}. A
     * {@code '#'} or {@code '+'} becomes 1 and every other character becomes 0. Characters past
     * column SIZE-1 are ignored and {@code '\r'} is skipped, so CRLF files work. Rows that are
     * short or missing (truncated file) are left at 0.
     *
     * @return {@code false} if end-of-file was hit before any character of the image was read
     */
    private static boolean readImage(Reader in, double[][] target) throws IOException {
        for (double[] row : target) {
            Arrays.fill(row, 0.0);
        }
        for (int row = 0; row < SIZE; row++) {
            int c = in.read();
            if (c == -1) {
                return row > 0; // nothing read at all means there is no image
            }
            int col = 0;
            while (c != -1 && c != '\n') {
                if (c != '\r') {
                    if (col < SIZE && (c == '#' || c == '+')) {
                        target[row][col] = 1;
                    }
                    col++;
                }
                c = in.read();
            }
            if (c == -1) {
                return true; // final line without a trailing newline
            }
        }
        return true;
    }
}

// You might also like