/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.lmt;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Additive logistic regression model fitted with LogitBoost, using one
 * {@link SimpleLinearRegression} per class per boosting iteration as the base
 * learners.
 *
 * <p>The number of boosting iterations is chosen in one of four ways (see
 * {@link #buildClassifier(Instances)}):
 * <ul>
 *   <li>a fixed count;</li>
 *   <li>an AIC-style criterion on the training data;</li>
 *   <li>internal cross-validation;</li>
 *   <li>the training error rate.</li>
 * </ul>
 * The per-class additive functions F_j(x) are turned into class probabilities
 * with a max-shifted softmax (see {@link #probs(double[])}).
 */
public class LogisticBase extends Classifier implements WeightedInstancesHandler {
    static final long serialVersionUID = 168765678097825064L;

    /** Empty header of {@link #m_numericData}; used as the dataset for instances being classified. */
    protected Instances m_numericDataHeader;
    /** Training data with the class attribute replaced by a numeric "pseudo class" attribute. */
    protected Instances m_numericData;
    /** Copy of the training data passed to {@link #buildClassifier(Instances)}. */
    protected Instances m_train;
    /** Whether to select the number of iterations by internal cross-validation. */
    protected boolean m_useCrossValidation;
    /** If true, cross-validation scores iterations by mean absolute error instead of error rate. */
    protected boolean m_errorOnProbabilities;
    /** Fixed number of boosting iterations; a value &lt;= 0 means "select automatically". */
    protected int m_fixedNumIterations;
    /** Stop boosting once the best score has not improved for more than this many iterations. */
    protected int m_heuristicStop = 50;
    /** Number of boosting iterations that currently make up the model. */
    protected int m_numRegressions = 0;
    /** Upper bound on the number of iterations when the count is selected automatically. */
    protected int m_maxIterations;
    /** Number of class values. */
    protected int m_numClasses;
    /** Base learners, indexed {@code [class][iteration]}. */
    protected SimpleLinearRegression[][] m_regressions;
    /** Number of folds used by the internal cross-validation. */
    protected static int m_numFoldsBoosting = 5;
    /** Bound on the absolute value of the LogitBoost working response z. */
    protected static final double Z_MAX = 3.0;
    /** Whether to select the number of iterations with the AIC-style criterion. */
    private boolean m_useAIC = false;
    /** Extra parameter count added to the iteration count in the AIC-style criterion. */
    protected double m_numParameters = 0.0;
    /** Fraction of total instance weight that may be trimmed per iteration (0 disables trimming). */
    protected double m_weightTrimBeta = 0.0;

    /**
     * Creates a model that selects the number of iterations by cross-validation
     * on the error rate, with at most 500 iterations.
     */
    public LogisticBase() {
        this(-1, true, false);
    }

    /**
     * Creates a model with the given iteration-selection settings and a
     * maximum of 500 iterations.
     *
     * @param numBoostingIterations fixed number of iterations, or &lt;= 0 to select automatically
     * @param useCrossValidation whether automatic selection uses cross-validation
     * @param errorOnProbabilities whether cross-validation uses mean absolute error instead of error rate
     */
    public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities) {
        this.m_fixedNumIterations = numBoostingIterations;
        this.m_useCrossValidation = useCrossValidation;
        this.m_errorOnProbabilities = errorOnProbabilities;
        this.m_maxIterations = 500;
    }

    /**
     * Fits the model to the given data.
     *
     * <p>Selection precedence is: a fixed iteration count (if &gt; 0), then the
     * AIC-style criterion, then cross-validation, and finally the training
     * error rate. Afterwards the regression array is trimmed to the iterations
     * that were kept.
     *
     * @param data training instances (copied, not modified)
     * @throws Exception if building a base learner or evaluating the model fails
     */
    public void buildClassifier(Instances data) throws Exception {
        this.m_train = new Instances(data);
        this.m_numClasses = this.m_train.numClasses();
        this.m_regressions = this.initRegressions();
        this.m_numRegressions = 0;
        this.m_numericData = this.getNumericData(this.m_train);
        this.m_numericDataHeader = new Instances(this.m_numericData, 0);
        if (this.m_fixedNumIterations > 0) {
            this.performBoosting(this.m_fixedNumIterations);
        } else if (this.m_useAIC) {
            this.performBoostingInfCriterion();
        } else if (this.m_useCrossValidation) {
            this.performBoostingCV();
        } else {
            this.performBoosting();
        }
        this.m_regressions = this.selectRegressions(this.m_regressions);
    }

    /**
     * Selects the number of iterations by {@link #m_numFoldsBoosting}-fold
     * cross-validation, then refits that many iterations on all training data.
     *
     * @throws Exception if building or evaluating a fold model fails
     */
    protected void performBoostingCV() throws Exception {
        int completedIterations = this.m_maxIterations;
        Instances allData = new Instances(this.m_train);
        allData.stratify(m_numFoldsBoosting);
        // error[i] accumulates the test error after i iterations, summed over folds.
        double[] error = new double[this.m_maxIterations + 1];
        for (int fold = 0; fold < m_numFoldsBoosting; ++fold) {
            Instances train = allData.trainCV(m_numFoldsBoosting, fold);
            Instances test = allData.testCV(m_numFoldsBoosting, fold);
            this.m_numRegressions = 0;
            this.m_regressions = this.initRegressions();
            int iterations = this.performBoosting(train, test, error, completedIterations);
            // Keep the minimum across folds: only indices up to it have a
            // contribution from every fold.
            if (iterations < completedIterations) {
                completedIterations = iterations;
            }
        }
        int bestIteration = this.getBestIteration(error, completedIterations);
        this.m_numRegressions = 0;
        this.performBoosting(bestIteration);
    }

    /**
     * Selects the number of iterations that minimises
     * {@code 2 * negLogLikelihood + 2 * (m_numParameters + iterations)} on the
     * training data, then refits that many iterations.
     *
     * @throws Exception if building a base learner fails
     */
    protected void performBoostingInfCriterion() throws Exception {
        double bestCriterion = Double.MAX_VALUE;
        int bestIteration = 0;
        int noMin = 0;
        double[][] trainYs = this.getYs(this.m_train);
        double[][] trainFs = this.getFs(this.m_numericData);
        double[][] probs = this.getProbs(trainFs);
        int iteration = 0;
        while (iteration < this.m_maxIterations
                && this.performIteration(iteration, trainYs, trainFs, probs, this.m_numericData)) {
            this.m_numRegressions = ++iteration;
            double numberOfAttributes = this.m_numParameters + (double) iteration;
            double criterionValue = 2.0 * this.negativeLogLikelihood(trainYs, probs) + 2.0 * numberOfAttributes;
            // Heuristic early stop: the minimum has not moved for a while.
            if (noMin > this.m_heuristicStop) {
                break;
            }
            if (criterionValue < bestCriterion) {
                bestCriterion = criterionValue;
                bestIteration = iteration;
                noMin = 0;
            } else {
                ++noMin;
            }
        }
        this.m_numRegressions = 0;
        this.performBoosting(bestIteration);
    }

    /**
     * Boosts on {@code train} and adds the test error after each iteration
     * (including iteration 0) into {@code error}.
     *
     * @param train training fold (nominal class)
     * @param test test fold used to score each iteration
     * @param error accumulator indexed by iteration count; must hold at least {@code maxIterations + 1} entries
     * @param maxIterations upper bound on iterations for this fold
     * @return number of iterations actually performed
     * @throws Exception if building a base learner or evaluating fails
     */
    protected int performBoosting(Instances train, Instances test, double[] error, int maxIterations) throws Exception {
        Instances numericTrain = this.getNumericData(train);
        double[][] trainYs = this.getYs(train);
        double[][] trainFs = this.getFs(numericTrain);
        double[][] probs = this.getProbs(trainFs);
        int iteration = 0;
        int noMin = 0;
        double lastMin = Double.MAX_VALUE;
        error[0] += this.m_errorOnProbabilities ? this.getMeanAbsoluteError(test) : this.getErrorRate(test);
        while (iteration < maxIterations
                && this.performIteration(iteration, trainYs, trainFs, probs, numericTrain)) {
            // m_numRegressions must be updated before evaluating, since the
            // evaluation classifies with this model.
            this.m_numRegressions = ++iteration;
            error[iteration] += this.m_errorOnProbabilities ? this.getMeanAbsoluteError(test) : this.getErrorRate(test);
            if (noMin > this.m_heuristicStop) {
                break;
            }
            if (error[iteration] < lastMin) {
                lastMin = error[iteration];
                noMin = 0;
            } else {
                ++noMin;
            }
        }
        return iteration;
    }

    /**
     * Fits up to {@code numIterations} iterations on the full training data.
     * It stops early if a base learner finds no useful attribute.
     *
     * @param numIterations maximum number of iterations to perform
     * @throws Exception if building a base learner fails
     */
    protected void performBoosting(int numIterations) throws Exception {
        double[][] trainYs = this.getYs(this.m_train);
        double[][] trainFs = this.getFs(this.m_numericData);
        double[][] probs = this.getProbs(trainFs);
        int iteration = 0;
        while (iteration < numIterations
                && this.performIteration(iteration, trainYs, trainFs, probs, this.m_numericData)) {
            ++iteration;
        }
        this.m_numRegressions = iteration;
    }

    /**
     * Boosts on the full training data, then keeps the iteration count with the
     * lowest training error rate.
     *
     * @throws Exception if building a base learner or evaluating fails
     */
    protected void performBoosting() throws Exception {
        double[][] trainYs = this.getYs(this.m_train);
        double[][] trainFs = this.getFs(this.m_numericData);
        double[][] probs = this.getProbs(trainFs);
        int iteration = 0;
        double[] trainErrors = new double[this.m_maxIterations + 1];
        trainErrors[0] = this.getErrorRate(this.m_train);
        int noMin = 0;
        double lastMin = Double.MAX_VALUE;
        while (iteration < this.m_maxIterations
                && this.performIteration(iteration, trainYs, trainFs, probs, this.m_numericData)) {
            this.m_numRegressions = ++iteration;
            trainErrors[iteration] = this.getErrorRate(this.m_train);
            if (noMin > this.m_heuristicStop) {
                break;
            }
            if (trainErrors[iteration] < lastMin) {
                lastMin = trainErrors[iteration];
                noMin = 0;
            } else {
                ++noMin;
            }
        }
        this.m_numRegressions = this.getBestIteration(trainErrors, iteration);
    }

    /**
     * Returns this model's error rate on {@code data}, as computed by {@link Evaluation}.
     *
     * @param data instances to evaluate on
     * @return the error rate
     * @throws Exception if evaluation fails
     */
    protected double getErrorRate(Instances data) throws Exception {
        Evaluation eval = new Evaluation(data);
        eval.evaluateModel(this, data, new Object[0]);
        return eval.errorRate();
    }

    /**
     * Returns this model's mean absolute error on {@code data}, as computed by {@link Evaluation}.
     *
     * @param data instances to evaluate on
     * @return the mean absolute error
     * @throws Exception if evaluation fails
     */
    protected double getMeanAbsoluteError(Instances data) throws Exception {
        Evaluation eval = new Evaluation(data);
        eval.evaluateModel(this, data, new Object[0]);
        return eval.meanAbsoluteError();
    }

    /**
     * Returns the index in {@code [0, maxIteration]} with the smallest error.
     * Ties go to the earliest index.
     *
     * @param errors error per iteration count
     * @param maxIteration last index to consider (inclusive)
     * @return the best iteration count
     */
    protected int getBestIteration(double[] errors, int maxIteration) {
        double bestError = errors[0];
        int bestIteration = 0;
        for (int i = 1; i <= maxIteration; ++i) {
            if (errors[i] < bestError) {
                bestError = errors[i];
                bestIteration = i;
            }
        }
        return bestIteration;
    }

    /**
     * Performs one LogitBoost iteration.
     *
     * <p>For each class j it fits {@code m_regressions[j][iteration]} to the
     * working response z with weights {@code (y - p) / z}. It then updates
     * {@code trainFs} and {@code probs} in place.
     *
     * @param iteration index of the iteration (column of {@link #m_regressions})
     * @param trainYs 0/1 class indicators, {@code [instance][class]}
     * @param trainFs current additive function values, updated in place
     * @param probs current class probabilities, updated in place
     * @param trainNumeric training data with numeric pseudo class (not modified)
     * @return false if some class's regression found no useful attribute. In
     *     that case {@code trainFs}/{@code probs} are left unchanged.
     * @throws Exception if building a base learner fails
     */
    protected boolean performIteration(int iteration, double[][] trainYs, double[][] trainFs, double[][] probs, Instances trainNumeric) throws Exception {
        int numInstances = trainNumeric.numInstances();
        for (int j = 0; j < this.m_numClasses; ++j) {
            double[] weights = new double[numInstances];
            double weightSum = 0.0;
            Instances boostData = new Instances(trainNumeric);
            for (int i = 0; i < numInstances; ++i) {
                double p = probs[i][j];
                double actual = trainYs[i][j];
                double z = this.getZ(actual, p);
                double w = (actual - p) / z;
                Instance current = boostData.instance(i);
                current.setValue(boostData.classIndex(), z);
                current.setWeight(current.weight() * w);
                weights[i] = current.weight();
                weightSum += current.weight();
            }
            Instances instancesCopy = new Instances(boostData);
            if (weightSum > 0.0) {
                if (this.m_weightTrimBeta > 0.0) {
                    // Weight trimming: keep the heaviest instances until they
                    // account for at least (1 - beta) of the total weight.
                    double weightPercentage = 0.0;
                    int[] weightsOrder = Utils.sort(weights);
                    instancesCopy.delete();
                    for (int k = weightsOrder.length - 1; k >= 0 && weightPercentage < 1.0 - this.m_weightTrimBeta; --k) {
                        instancesCopy.add(boostData.instance(weightsOrder[k]));
                        weightPercentage += weights[weightsOrder[k]] / weightSum;
                    }
                }
                // Rescale so the weights sum to the number of instances kept.
                weightSum = instancesCopy.sumOfWeights();
                for (int i = 0; i < instancesCopy.numInstances(); ++i) {
                    Instance current = instancesCopy.instance(i);
                    current.setWeight(current.weight() * (double) instancesCopy.numInstances() / weightSum);
                }
            }
            this.m_regressions[j][iteration].buildClassifier(instancesCopy);
            if (!this.m_regressions[j][iteration].foundUsefulAttribute()) {
                return false;
            }
        }
        // F_j += (J-1)/J * (f_j - mean_k f_k), then recompute probabilities.
        for (int i = 0; i < trainFs.length; ++i) {
            double[] pred = new double[this.m_numClasses];
            double predSum = 0.0;
            for (int j = 0; j < this.m_numClasses; ++j) {
                pred[j] = this.m_regressions[j][iteration].classifyInstance(trainNumeric.instance(i));
                predSum += pred[j];
            }
            predSum /= (double) this.m_numClasses;
            for (int j = 0; j < this.m_numClasses; ++j) {
                trainFs[i][j] += (pred[j] - predSum) * (double) (this.m_numClasses - 1) / (double) this.m_numClasses;
            }
        }
        for (int i = 0; i < trainYs.length; ++i) {
            probs[i] = this.probs(trainFs[i]);
        }
        return true;
    }

    /**
     * Allocates fresh base learners for every class and iteration slot.
     *
     * <p>The array is sized for {@link #m_maxIterations} or
     * {@link #m_fixedNumIterations}, whichever is larger. A fixed iteration
     * count above the maximum therefore does not run past the end of the array.
     *
     * @return a {@code [numClasses][slots]} array of new regressions
     */
    protected SimpleLinearRegression[][] initRegressions() {
        int numSlots = Math.max(this.m_maxIterations, this.m_fixedNumIterations);
        SimpleLinearRegression[][] classifiers = new SimpleLinearRegression[this.m_numClasses][numSlots];
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int i = 0; i < numSlots; ++i) {
                classifiers[j][i] = new SimpleLinearRegression();
                classifiers[j][i].setSuppressErrorMessage(true);
            }
        }
        return classifiers;
    }

    /**
     * Returns a copy of {@code data} whose class attribute is replaced, at the
     * same index, by a numeric attribute named {@code 'pseudo class'}.
     *
     * @param data instances with a class attribute set
     * @return the converted copy
     * @throws Exception if the attribute manipulation fails
     */
    protected Instances getNumericData(Instances data) throws Exception {
        Instances numericData = new Instances(data);
        int classIndex = numericData.classIndex();
        numericData.setClassIndex(-1);
        numericData.deleteAttributeAt(classIndex);
        numericData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
        numericData.setClassIndex(classIndex);
        return numericData;
    }

    /**
     * Returns the first {@link #m_numRegressions} iterations of {@code classifiers}.
     *
     * @param classifiers regressions indexed {@code [class][iteration]}
     * @return a trimmed {@code [numClasses][m_numRegressions]} array sharing the same objects
     */
    protected SimpleLinearRegression[][] selectRegressions(SimpleLinearRegression[][] classifiers) {
        SimpleLinearRegression[][] goodClassifiers = new SimpleLinearRegression[this.m_numClasses][this.m_numRegressions];
        for (int j = 0; j < this.m_numClasses; ++j) {
            System.arraycopy(classifiers[j], 0, goodClassifiers[j], 0, this.m_numRegressions);
        }
        return goodClassifiers;
    }

    /**
     * Computes the LogitBoost working response, clipped to {@code [-Z_MAX, Z_MAX]}.
     *
     * @param actual 1.0 if the instance belongs to the class, otherwise treated as 0
     * @param p current probability estimate for the class
     * @return {@code 1/p} if actual is 1, else {@code -1/(1-p)}, clipped
     */
    protected double getZ(double actual, double p) {
        double z;
        if (actual == 1.0) {
            z = 1.0 / p;
            if (z > Z_MAX) {
                z = Z_MAX;
            }
        } else {
            z = -1.0 / (1.0 - p);
            if (z < -Z_MAX) {
                z = -Z_MAX;
            }
        }
        return z;
    }

    /**
     * Computes {@link #getZ(double, double)} for every instance and class.
     *
     * @param probs class probabilities, {@code [instance][class]}
     * @param dataYs 0/1 class indicators, {@code [instance][class]}
     * @return working responses, {@code [instance][class]}
     */
    protected double[][] getZs(double[][] probs, double[][] dataYs) {
        double[][] dataZs = new double[probs.length][this.m_numClasses];
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int i = 0; i < probs.length; ++i) {
                dataZs[i][j] = this.getZ(dataYs[i][j], probs[i][j]);
            }
        }
        return dataZs;
    }

    /**
     * Computes the LogitBoost weights {@code (y - p) / z} for every instance and class.
     *
     * @param probs class probabilities, {@code [instance][class]}
     * @param dataYs 0/1 class indicators, {@code [instance][class]}
     * @return weights, {@code [instance][class]}
     */
    protected double[][] getWs(double[][] probs, double[][] dataYs) {
        double[][] dataWs = new double[probs.length][this.m_numClasses];
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int i = 0; i < probs.length; ++i) {
                double z = this.getZ(dataYs[i][j], probs[i][j]);
                dataWs[i][j] = (dataYs[i][j] - probs[i][j]) / z;
            }
        }
        return dataWs;
    }

    /**
     * Converts function values to probabilities with a softmax.
     *
     * <p>The maximum value is subtracted before exponentiating to avoid overflow.
     *
     * @param Fs per-class function values
     * @return normalised class probabilities
     */
    protected double[] probs(double[] Fs) {
        double maxF = -Double.MAX_VALUE;
        for (int i = 0; i < Fs.length; ++i) {
            if (Fs[i] > maxF) {
                maxF = Fs[i];
            }
        }
        double sum = 0.0;
        double[] probs = new double[Fs.length];
        for (int i = 0; i < Fs.length; ++i) {
            probs[i] = Math.exp(Fs[i] - maxF);
            sum += probs[i];
        }
        Utils.normalize(probs, sum);
        return probs;
    }

    /**
     * Builds 0/1 class indicator vectors.
     *
     * @param data instances with a class attribute
     * @return {@code [instance][class]} with 1.0 where the class value equals the class index
     */
    protected double[][] getYs(Instances data) {
        double[][] dataYs = new double[data.numInstances()][this.m_numClasses];
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int k = 0; k < data.numInstances(); ++k) {
                dataYs[k][j] = data.instance(k).classValue() == (double) j ? 1.0 : 0.0;
            }
        }
        return dataYs;
    }

    /**
     * Evaluates the additive functions F_j over the first {@link #m_numRegressions} iterations.
     *
     * @param instance instance compatible with the numeric header
     * @return per-class function values
     * @throws Exception if a base learner fails to classify
     */
    protected double[] getFs(Instance instance) throws Exception {
        double[] pred = new double[this.m_numClasses];
        double[] instanceFs = new double[this.m_numClasses];
        for (int i = 0; i < this.m_numRegressions; ++i) {
            double predSum = 0.0;
            for (int j = 0; j < this.m_numClasses; ++j) {
                pred[j] = this.m_regressions[j][i].classifyInstance(instance);
                predSum += pred[j];
            }
            predSum /= (double) this.m_numClasses;
            for (int j = 0; j < this.m_numClasses; ++j) {
                instanceFs[j] += (pred[j] - predSum) * (double) (this.m_numClasses - 1) / (double) this.m_numClasses;
            }
        }
        return instanceFs;
    }

    /**
     * Applies {@link #getFs(Instance)} to every instance.
     *
     * @param data instances compatible with the numeric header
     * @return {@code [instance][class]} function values
     * @throws Exception if a base learner fails to classify
     */
    protected double[][] getFs(Instances data) throws Exception {
        double[][] dataFs = new double[data.numInstances()][];
        for (int k = 0; k < data.numInstances(); ++k) {
            dataFs[k] = this.getFs(data.instance(k));
        }
        return dataFs;
    }

    /**
     * Applies {@link #probs(double[])} to each row.
     *
     * @param dataFs {@code [instance][class]} function values
     * @return {@code [instance][class]} probabilities
     */
    protected double[][] getProbs(double[][] dataFs) {
        int numInstances = dataFs.length;
        double[][] probs = new double[numInstances][];
        for (int k = 0; k < numInstances; ++k) {
            probs[k] = this.probs(dataFs[k]);
        }
        return probs;
    }

    /**
     * Computes the negative log-likelihood of the true classes.
     *
     * @param dataYs 0/1 class indicators, {@code [instance][class]}
     * @param probs class probabilities, {@code [instance][class]}
     * @return the sum of {@code -log p} over entries whose indicator is 1
     */
    protected double negativeLogLikelihood(double[][] dataYs, double[][] probs) {
        double logLikelihood = 0.0;
        for (int i = 0; i < dataYs.length; ++i) {
            for (int j = 0; j < this.m_numClasses; ++j) {
                if (dataYs[i][j] == 1.0) {
                    logLikelihood -= Math.log(probs[i][j]);
                }
            }
        }
        return logLikelihood;
    }

    /**
     * Lists, per class, the numeric-header attribute indices with a non-zero coefficient.
     *
     * @return {@code [class][]} ascending attribute indices
     */
    public int[][] getUsedAttributes() {
        int numAttributes = this.m_numericDataHeader.numAttributes();
        int[][] usedAttributes = new int[this.m_numClasses][];
        double[][] coefficients = this.getCoefficients();
        for (int j = 0; j < this.m_numClasses; ++j) {
            int[] candidates = new int[numAttributes];
            int count = 0;
            for (int i = 0; i < numAttributes; ++i) {
                // coefficients[j][0] is the intercept, so attribute i sits at i + 1.
                if (!Utils.eq(coefficients[j][i + 1], 0.0)) {
                    candidates[count++] = i;
                }
            }
            int[] usedAttributesClass = new int[count];
            System.arraycopy(candidates, 0, usedAttributesClass, 0, count);
            usedAttributes[j] = usedAttributesClass;
        }
        return usedAttributes;
    }

    /** @return the number of boosting iterations in the model */
    public int getNumRegressions() {
        return this.m_numRegressions;
    }

    /** @return the weight-trimming fraction (0 means no trimming) */
    public double getWeightTrimBeta() {
        return this.m_weightTrimBeta;
    }

    /** @return whether the AIC-style criterion selects the number of iterations */
    public boolean getUseAIC() {
        return this.m_useAIC;
    }

    /** @param maxIterations upper bound on iterations when selecting automatically */
    public void setMaxIterations(int maxIterations) {
        this.m_maxIterations = maxIterations;
    }

    /** @param heuristicStop number of non-improving iterations tolerated before stopping */
    public void setHeuristicStop(int heuristicStop) {
        this.m_heuristicStop = heuristicStop;
    }

    /** @param w fraction of total weight that may be trimmed per iteration (0 disables trimming) */
    public void setWeightTrimBeta(double w) {
        this.m_weightTrimBeta = w;
    }

    /** @param c whether to select the number of iterations with the AIC-style criterion */
    public void setUseAIC(boolean c) {
        this.m_useAIC = c;
    }

    /** @return the upper bound on iterations when selecting automatically */
    public int getMaxIterations() {
        return this.m_maxIterations;
    }

    /**
     * Collapses the fitted regressions into one linear function per class.
     *
     * @return {@code [class][1 + numAttributes]}; index 0 is the intercept,
     *     index {@code a + 1} the coefficient of attribute {@code a}. All entries
     *     are scaled by {@code (J-1)/J}.
     */
    protected double[][] getCoefficients() {
        double[][] coefficients = new double[this.m_numClasses][this.m_numericDataHeader.numAttributes() + 1];
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int i = 0; i < this.m_numRegressions; ++i) {
                SimpleLinearRegression regression = this.m_regressions[j][i];
                coefficients[j][0] += regression.getIntercept();
                coefficients[j][regression.getAttributeIndex() + 1] += regression.getSlope();
            }
        }
        double factor = (double) (this.m_numClasses - 1) / (double) this.m_numClasses;
        for (int j = 0; j < coefficients.length; ++j) {
            for (int i = 0; i < coefficients[j].length; ++i) {
                coefficients[j][i] *= factor;
            }
        }
        return coefficients;
    }

    /**
     * Percentage of attributes with a non-zero coefficient for at least one class.
     *
     * <p>The denominator is {@code numAttributes() - 1}, i.e. the header size
     * minus the pseudo class. Returns 0 when there are no other attributes,
     * instead of dividing by zero.
     *
     * @return a percentage in [0, 100]
     */
    public double percentAttributesUsed() {
        int numAttributes = this.m_numericDataHeader.numAttributes();
        boolean[] attributes = new boolean[numAttributes];
        double[][] coefficients = this.getCoefficients();
        for (int j = 0; j < this.m_numClasses; ++j) {
            for (int i = 1; i < numAttributes + 1; ++i) {
                if (!Utils.eq(coefficients[j][i], 0.0)) {
                    attributes[i - 1] = true;
                }
            }
        }
        double count = 0.0;
        for (int i = 0; i < attributes.length; ++i) {
            if (attributes[i]) {
                count += 1.0;
            }
        }
        int numPredictors = numAttributes - 1;
        if (numPredictors <= 0) {
            return 0.0;
        }
        return count / (double) numPredictors * 100.0;
    }

    /**
     * Renders each class's linear function: the intercept, then
     * {@code [attribute] * coefficient} for every used attribute.
     *
     * @return human-readable model description
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        int[][] attributes = this.getUsedAttributes();
        double[][] coefficients = this.getCoefficients();
        for (int j = 0; j < this.m_numClasses; ++j) {
            s.append("\nClass " + j + " :\n");
            s.append(Utils.doubleToString(coefficients[j][0], 4, 2) + " + \n");
            for (int i = 0; i < attributes[j].length; ++i) {
                int attIndex = attributes[j][i];
                s.append("[" + this.m_numericDataHeader.attribute(attIndex).name() + "]");
                s.append(" * " + Utils.doubleToString(coefficients[j][attIndex + 1], 4, 2));
                if (i != attributes[j].length - 1) {
                    s.append(" +");
                }
                s.append("\n");
            }
        }
        return s.toString();
    }

    /**
     * Returns class probabilities for an instance.
     *
     * <p>A copy of the instance is attached to the numeric header before
     * evaluating the additive functions, so the caller's instance is not modified.
     *
     * @param instance instance to classify
     * @return class probability distribution
     * @throws Exception if a base learner fails to classify
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        instance = (Instance) instance.copy();
        instance.setDataset(this.m_numericDataHeader);
        return this.probs(this.getFs(instance));
    }

    /** Drops the stored training instances (keeping the header) and the numeric training data. */
    public void cleanup() {
        this.m_train = new Instances(this.m_train, 0);
        this.m_numericData = null;
    }

    /** @return the revision string */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.9 $");
    }
}

