import java.util.*;

/**
 * Small demonstration of k-means clustering (Lloyd's algorithm).
 *
 * <p>Each epoch assigns every point in {@link #data} to its nearest centroid
 * (Euclidean distance) and then replaces every centroid with the mean of the
 * points assigned to it. The centroids are printed at the start of each epoch.
 *
 * <p>The number of centroids and the number of dimensions are taken from the
 * {@link #centroids} array, so both can be changed before calling
 * {@link #run(int)}.
 */
public class KMeansClustering {

    /** Points to cluster: one row per point, one column per dimension. */
    double[][] data = {{2, 3}, {1, 9}, {0, 6}, {8, 9}, {9, 6}, {6, 3}};

    /** Current centroids; {@link #run(int)} replaces this array after every epoch. */
    double[][] centroids = {{2, 3}, {1, 9}};

    /**
     * Computes the Euclidean distance between a point and a centroid.
     *
     * @param datum    the point; its length determines how many dimensions are compared
     * @param centroid the centroid; must have at least {@code datum.length} elements
     * @return the square root of the sum of squared per-dimension differences
     */
    double getDistance(double[] datum, double[] centroid) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < datum.length; i++) {
            sumOfSquares += Math.pow(datum[i] - centroid[i], 2);
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Finds the centroid nearest to the given point.
     *
     * @param datum the point to classify
     * @return the index into {@link #centroids} of the nearest centroid (the
     *         lowest index wins ties), or {@code -1} if there are no centroids
     *         or no distance compares below {@link Double#MAX_VALUE}
     *         (for example when every distance is NaN)
     */
    int getClosestCentroid(double[] datum) {
        double best = Double.MAX_VALUE;
        int bestIndex = -1;
        for (int i = 0; i < centroids.length; i++) {
            double distance = getDistance(datum, centroids[i]);
            // Strict '<' keeps the first centroid on ties and rejects NaN distances.
            if (distance < best) {
                bestIndex = i;
                best = distance;
            }
        }
        return bestIndex;
    }

    /**
     * Prints one point to standard output in the form {@code [x, y, ...]}.
     *
     * @param datum the point to print
     */
    void printDatum(double[] datum) {
        // Arrays.toString yields the same "[a, b]" text the former Vector-based
        // code produced, without raw types or the deprecated Double constructor.
        System.out.println(Arrays.toString(datum));
    }

    /** Prints every current centroid, one per line, followed by a separator line. */
    void printCentroids() {
        for (double[] centroid : centroids) {
            printDatum(centroid);
        }
        System.out.println("-------------------");
    }

    /**
     * Runs the assignment/update loop for a fixed number of epochs.
     *
     * <p>A centroid that receives no points in an epoch keeps its previous
     * position instead of being divided by zero (which would turn it into NaN).
     * Points for which {@link #getClosestCentroid(double[])} returns {@code -1}
     * are skipped for that epoch.
     *
     * @param nEpochs number of epochs to run; zero or negative runs nothing
     */
    void run(int nEpochs) {
        for (int epoch = 0; epoch < nEpochs; epoch++) {
            printCentroids();

            int clusterCount = centroids.length;
            double[][] sums = new double[clusterCount][];
            int[] counts = new int[clusterCount];
            for (int c = 0; c < clusterCount; c++) {
                sums[c] = new double[centroids[c].length];
            }

            // Assignment step: accumulate each point into its nearest cluster.
            for (double[] datum : data) {
                int nearest = getClosestCentroid(datum);
                if (nearest < 0) {
                    continue; // no comparable centroid for this point
                }
                for (int d = 0; d < sums[nearest].length; d++) {
                    sums[nearest][d] += datum[d];
                }
                counts[nearest]++;
            }

            // Update step: turn sums into means, preserving empty clusters.
            for (int c = 0; c < clusterCount; c++) {
                if (counts[c] == 0) {
                    sums[c] = centroids[c].clone();
                } else {
                    for (int d = 0; d < sums[c].length; d++) {
                        sums[c][d] /= counts[c];
                    }
                }
            }
            centroids = sums;
        }
    }

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} is the number of epochs (default 5)
     */
    public static void main(String[] args) {
        int epochs = 5;
        if (args.length > 0) {
            epochs = Integer.parseInt(args[0]);
        }
        new KMeansClustering().run(epochs);
    }
}