softmax

来源:互联网 发布:mysql 表 字段关联 编辑:程序博客网 时间:2024/06/11 14:03
package com.gridnt.offline;

public class offlineCal {
    static final int K = 2;// class num
    static int M = 9;// train numbers
    static int N = 4;// features
    static final double E = 2.718281828;
    static double x[][] = new double[][] {
            { 1, 47, 76, 24 }, // include x0=1
            { 1, 46, 77, 23 }, { 1, 48, 74, 22 }, { 1, 34, 76, 21 }, { 1, 35, 75, 24 },
            { 1, 34, 77, 25 }, { 1, 55, 76, 21 }, { 1, 56, 74, 22 }, { 1, 55, 72, 22 }, };

    static int y[] = new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, };

    static double theta[][] = new double[][] { { 0.3, 0.3, 0.01, 0.01 }, { 0.5, 0.5, 0.01, 0.01 } }; // include

    // theta0

    static double Efun(double p) {
        return Math.pow(E, p);
    }

    static double phi[] = new double[K + 1];
    static double H[][] = new double[M][K + 1];

    static void calH() {
        for (int i = 0; i < M; i++) {
            double p1 = 1;
            for (int j = 0; j < K; j++) {
                double sum = 0;
                for (int n = 0; n < N; n++) {
                    sum += (theta[j][n] * x[i][n]);
                }
                phi[j] = Efun(sum);
                p1 += Efun(sum);
            }
            H[i][0] = phi[0] /= p1;
            H[i][1] = phi[1] /= p1;
            H[i][2] = phi[2] = 1 - phi[0] - phi[1];
        }
    }

    static double calcLikelyHood() {
        calH();
        double likelihood = 0;
        for (int i = 0; i < M; i++) {
            likelihood += Math.log(H[i][y[i] - 1]);
        }
        System.err.println(likelihood);
        return likelihood;
    }

    // theta0
    public static void main(String[] args) {

        for (int w = 0; w < 100000; w++) {
            for (int i = 0; i < M; i++) {
                double p1 = 1;
                for (int j = 0; j < K; j++) {
                    double sum = 0;
                    for (int n = 0; n < N; n++) {
                        sum += (theta[j][n] * x[i][n]);
                    }
                    phi[j] = Efun(sum);
                    p1 += Efun(sum);
                }
                phi[0] /= p1;
                phi[1] /= p1;
                // phi[2] = 1 - phi[0] - phi[1];
                for (int a = 0; a < K; a++) {
                    for (int j = 0; j < N; j++) {
                        theta[a][j] += 0.001 * x[i][j] * (((y[i] == a + 1) ? 1 : 0) - phi[a]);
                    }
                }
            }
            calcLikelyHood();
        }
        calH();
        System.err.println(theta);
    }
}