나이브 베이지안 분류 기법(Naive Bayesian classifier)은 학습, 구현 과정이 쉽지만 성능도 잘 나오는 분류 방법입니다.
이 방법을 쓰려면 모든 속성이 서로 독립적이라는 가정이 있어야 합니다.
간략한 설명을 보시려면 http://bcho.tistory.com/1010 링크를 확인하시기 바랍니다.
이번 포스트에서는 나이브 베이지안 분류기를 자바로 구현하고자 하였습니다.
또한 구현한 모델(오브젝트)을 파일로 저장하고, 추후 읽어와서 쓸 수 있게 하여 학습 데이터를 매번 학습할 필요가 없도록 했습니다.
아래는 나이브 베이지안을 자바로 구현한 예제코드 전문입니다.
깃허브 파일 링크 : [1]
import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Set; public class NaiveBayes { public static ArrayList<String> getData(String fileName) { ArrayList<String> data = new ArrayList<String>(); BufferedReader inputStream = null; try { inputStream = new BufferedReader(new FileReader(fileName)); String l; while ((l = inputStream.readLine()) != null) data.add(l); } catch (IOException e) { System.err.println("getData: "+e.getMessage()); System.exit(1); } return data; } public static Double[] getGaussianModel(ArrayList<Double> values) { Double[] g = new Double[2]; double mu = 0d; double ss = 0d; for (int i = 0; i < values.size(); i++) mu += values.get(i); mu /= values.size(); for (int i = 0; i < values.size(); i++) ss += (values.get(i)-mu)*(values.get(i)-mu); ss /= values.size(); g[0] = mu; g[1] = Math.sqrt(ss); return g; } public static HashMap<String,Double[]> getNumerical(ArrayList<String> data, int col) { HashMap<String,Double[]> prb = new HashMap<String,Double[]>(); // value = {mu, sigma} HashMap<String,ArrayList<Double>> classes = new HashMap<String,ArrayList<Double>>(); for (int i = 0; i < data.size(); i++) { String[] arr = data.get(i).split(","); ArrayList<Double> values = new ArrayList<Double>(); if (classes.containsKey(arr[arr.length-1])) values = classes.get(arr[arr.length-1]); values.add(Double.parseDouble(arr[col])); classes.put(arr[arr.length-1], values); } Iterator<String> itr = classes.keySet().iterator(); while(itr.hasNext()) { String key = itr.next(); prb.put(key, getGaussianModel(classes.get(key))); } return prb; } public static HashMap<String,Double> getCategorical(ArrayList<String> data, int col) { HashMap<String,Double> prb = new HashMap<String,Double>(); HashMap<String,Integer> counts = new HashMap<String,Integer>(); int countAll = 0; for (int i = 0; i < data.size(); i++) { String[] arr = data.get(i).split(","); int count = 0; if (counts.containsKey(arr[col]+","+arr[arr.length-1])) count = counts.get(arr[col]+","+arr[arr.length-1]); count++; countAll++; counts.put(arr[col]+","+arr[arr.length-1], count); } for (int i = 0; i < counts.keySet().size(); i++) { String key = (String) counts.keySet().toArray()[i]; prb.put(key, (double)counts.get(key) / (double)countAll); } return prb; } public static HashMap<String,Double> getClass(ArrayList<String> data) { HashMap<String,Double> prb = new HashMap<String,Double>(); HashMap<String,Integer> counts = new HashMap<String,Integer>(); int countAll = 0; for (int i = 0; i < data.size(); i++) { String[] arr = data.get(i).split(","); int count = 0; if (counts.containsKey(arr[arr.length-1])) count = counts.get(arr[arr.length-1]); count++; countAll++; counts.put(arr[arr.length-1], count); } for (int i = 0; i < counts.keySet().size(); i++) { String key = (String)counts.keySet().toArray()[i]; int value = counts.get(key); prb.put(key, (double)value/(double)countAll); } return prb; } public static void printPrb(HashMap<String,Double> prb) { System.out.println("printPrb"); for (int i = 0; i < prb.keySet().size(); i++) { String key = (String) prb.keySet().toArray()[i]; System.out.println(key+"\t"+prb.get(key)); } System.out.println(); } public static void printPrb2(HashMap<String,Double[]> prb) { System.out.println("printPrb"); for (int i = 0; i < prb.keySet().size(); i++) { String key = (String) prb.keySet().toArray()[i]; System.out.println(key+"\t"+prb.get(key)[0]+"\t"+prb.get(key)[1]); } System.out.println(); } public static double getGassusianValue(Double[] ms, double val) { return 1.0/(ms[1]*Math.sqrt(2*Math.PI))*Math.exp(-(val-ms[0])*(val-ms[0]) / (2.0*ms[1]*ms[1])); } public static void printAttribute(String str) { String[] arr = str.split(","); for (int i = 0; i < arr.length; i++) System.out.print(arr[i]+" "); System.out.println(); } public static void getTest(ArrayList<String> test, HashMap<String,Double> prb_class, HashMap<String,Object>[] prb_attributes, int[] categoric, int[] numeric, int N) { Set<String> classes = prb_class.keySet(); for (int i = 0; i < test.size(); i++) { double yesOrNo = Double.MIN_VALUE; String isYes = ""; String[] attr = test.get(i).split(","); // PRINT ATTRIBUTES printAttribute(test.get(i)); for (int c = 0; c < classes.size(); c++) { double cond_prb = 1d; String cls = (String)classes.toArray()[c]; for (int j = 0; j < categoric.length; j++) { int col = categoric[j]; StringBuffer key = new StringBuffer(); key.append(attr[col]); key.append(","); key.append(cls); HashMap<String,Object> temp_prb = prb_attributes[col]; if (temp_prb.containsKey(key.toString())) cond_prb *= (Double)temp_prb.get(key.toString()); else cond_prb = 1.0 / ((double) N+temp_prb.keySet().size()); // simply smoothing if (cond_prb > yesOrNo) { yesOrNo = cond_prb; isYes = cls; } // TEST // System.out.println(key.toString()+"\t"+(Double)temp_prb.get(key.toString())); // System.out.println(); } for (int j = 0; j < numeric.length; j++) { int col = numeric[j]; StringBuffer key = new StringBuffer(); key.append(cls); HashMap<String,Object> temp_prb = prb_attributes[col]; Double[] ms = (Double[]) temp_prb.get(key.toString()); cond_prb *= getGassusianValue(ms, Double.parseDouble(attr[col])); if (cond_prb > yesOrNo) { yesOrNo = cond_prb; isYes = cls; } // TEST // System.out.println(key.toString()+"\t"+getGassusianValue(ms, Double.parseDouble(attr[col]))); // System.out.println(); } System.out.println(cls+"\t"+cond_prb); } System.out.println(isYes+"\n"); } } public static void saveModel(Object obj, String fileName) { try { FileOutputStream fout = new FileOutputStream(fileName); ObjectOutputStream oos = new ObjectOutputStream(fout); oos.writeObject(obj); } catch (IOException e) { System.err.println("saveModel: "+e.getMessage()); System.exit(1); } } public static Object getModel(String fileName) { Object obj = new Object(); try { FileInputStream fin = new FileInputStream(fileName); ObjectInputStream ios = new ObjectInputStream(fin); obj = ios.readObject(); } catch (IOException e) { System.err.println("getModel: "+e.getMessage()); System.exit(1); } catch (ClassNotFoundException e) { System.err.println("getModel: "+e.getMessage()); System.exit(1); } return obj; } public static void main(String[] args) { String fileName = "/home/spark/bigDataProcessing/module2/play_tennis.csv"; String fileName2 = "/home/spark/bigDataProcessing/module2/play_tennis_test.csv"; ArrayList<String> training = getData(fileName); ArrayList<String> test = getData(fileName2); int[] col_categoric = {0,3}; int[] col_numeric = {1,2}; // PROBABILITY of CLASSES HashMap<String,Double> prb_class = getClass(training); // printPrb(prb_class); // PROBABILITY of ATTRIBUTES HashMap[] prb_attributes; prb_attributes = new HashMap[4]; // For CATEGORIC ATTRIBUTES for (int i = 0; i < col_categoric.length; i++) { prb_attributes[col_categoric[i]] = getCategorical(training, col_categoric[i]); // printPrb(prb_attributes[col_categoric[i]]); } // For NUMERIC ATTRIBUTES for (int i = 0; i < col_numeric.length; i++) { prb_attributes[col_numeric[i]] = getNumerical(training, col_numeric[i]); // printPrb2(prb_attributes[col_numeric[i]]); } // TEST with MODEL OBJECT getTest(test, prb_class, prb_attributes, col_categoric, col_numeric, training.size()); // SAVE MODEL as FILE String file1 = "prb_class"; String file2 = "prb_attributes"; saveModel(prb_class, file1); saveModel(prb_attributes, file2); // GET MODEL from FILE HashMap<String,Double> prb_class2 = (HashMap<String,Double>) getModel(file1); HashMap[] prb_attributes2 = (HashMap[]) getModel(file2); // TEST from FILE OBJECT getTest(test, prb_class2, prb_attributes2, col_categoric, col_numeric, training.size()); } }
소스에서 saveModel, getModel 함수로 오브젝트를 파일로 쓰고, 읽는 예제를 포함했습니다.
public static void saveModel(Object obj, String fileName) { try { FileOutputStream fout = new FileOutputStream(fileName); ObjectOutputStream oos = new ObjectOutputStream(fout); oos.writeObject(obj); } catch (IOException e) { System.err.println("saveModel: "+e.getMessage()); System.exit(1); } } public static Object getModel(String fileName) { Object obj = new Object(); try { FileInputStream fin = new FileInputStream(fileName); ObjectInputStream ios = new ObjectInputStream(fin); obj = ios.readObject(); } catch (IOException e) { System.err.println("getModel: "+e.getMessage()); System.exit(1); } catch (ClassNotFoundException e) { System.err.println("getModel: "+e.getMessage()); System.exit(1); } return obj; }
메인 함수에서 하단 부에 파일로 읽어온 오브젝트를 가져오고 알맞게 형변환을 하여 모델을 테스트할 수 있습니다.
// GET MODEL from FILE HashMap<String,Double> prb_class2 = (HashMap<String,Double>) getModel(file1); HashMap[] prb_attributes2 = (HashMap[]) getModel(file2);
'노트정리 > 자바 JAVA' 카테고리의 다른 글
자바 함수에서 다른 종류의 오브젝트 반환하는 방법 (0) | 2017.10.05 |
---|---|
트위터 메시지를 태그 클라우드(tag cloud, word cloud) 만들기 예제 (0) | 2017.09.26 |
자바(Java)에서 HtmlUnit을 이용해 트위터 특정 사용자의 친구 목록 크롤링하는 예제 (0) | 2017.09.16 |
자바에서 json 튜토리얼 (1) | 2017.04.05 |
자바에서 파일 입출력하는 튜토리얼 (0) | 2017.03.30 |
JDBC 튜토리얼 사이트 소개와 예제 (0) | 2016.10.25 |
이클립스에서 the selection cannot be launched 에러 해결법. (0) | 2015.09.30 |
자바의 배열에서 C, C++과 가장 다른 점 (0) | 2014.02.13 |