package it.uniroma2.sag.kelp.examples.demo.tweetsent2013;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.kernel.cache.FixIndexSquaredNormCache;
import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache;
import it.uniroma2.sag.kelp.kernel.standard.LinearKernelCombination;
import it.uniroma2.sag.kelp.kernel.standard.NormalizationKernel;
import it.uniroma2.sag.kelp.kernel.standard.PolynomialKernel;
import it.uniroma2.sag.kelp.kernel.standard.RbfKernel;
import it.uniroma2.sag.kelp.kernel.vector.LinearKernel;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.BinaryCSvmClassification;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.CharEncoding;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/demo/tweetsent2013/TweetSentimentAnalysisSemeval2013.class */
/**
 * Demo: trains and evaluates a one-vs-all C-SVM sentiment classifier on the tweet datasets stored
 * under {@code src/main/resources/tweetSentiment2013/} (SemEval-2013 data, per the class name).
 *
 * <p>The SVM regularization parameter C is first tuned on a split of the training set. The
 * selected value is then used to train on the whole training set and evaluate on the test set.
 */
public class TweetSentimentAnalysisSemeval2013 {
    /** Column separator used in the printed report, the tuning log and the errors file. */
    private static final String FIELD_SEP = "\t";
    /** Directory holding the input datasets and every file this demo writes. */
    private static final String RESOURCE_DIR = "src/main/resources/tweetSentiment2013/";
    private static final String TRAIN_FILE = RESOURCE_DIR + "train.klp.gz";
    private static final String TEST_FILE = RESOURCE_DIR + "test.klp.gz";
    /** Per-example "gold, predicted, correct" log written by the final evaluation only. */
    private static final String ERRORS_FILE = RESOURCE_DIR + "errors.txt";
    /** Destination of the serialized learning-algorithm specification. */
    private static final String LEARNING_ALGORITHM_FILE =
            RESOURCE_DIR + "learningAlgorithmSpecification_multi.klp";
    /** Destination of the serialized trained classifier. */
    private static final String CLASSIFIER_FILE = RESOURCE_DIR + "classificationAlgorithm_bow_ws.klp";
    /** Fraction passed to the class-distribution-invariant split used while tuning C. */
    private static final float TUNING_SPLIT = 0.8f;

    /**
     * Loads the datasets, builds the selected kernel, tunes C and prints test-set performance.
     *
     * @param args unused
     * @throws Exception if loading, training, serialization or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        // Selects the kernel built by buildKernel; 1 is the normalized bag-of-words linear kernel.
        int kernelId = 1;
        float[] cValues = {0.1f, 0.5f, 1.0f};

        SimpleDataset trainingSet = new SimpleDataset();
        trainingSet.populate(TRAIN_FILE);
        SimpleDataset testSet = new SimpleDataset();
        testSet.populate(TEST_FILE);

        // Kernel caches are sized by the total number of training and test examples.
        int cacheSize = trainingSet.getNumberOfExamples() + testSet.getNumberOfExamples();
        Kernel kernel = buildKernel(kernelId, cacheSize);

        float bestC = tune(trainingSet, kernel, TUNING_SPLIT, cValues);
        System.out.println("start testing with C=" + bestC);
        System.out.println("Mean F1 on test set=" + test(trainingSet, kernel, bestC, testSet, true));
    }

    /**
     * Builds one of the demo kernels.
     *
     * @param kernelId 1 = BoW, 2 = polynomial BoW, 3 = word space, 4 = RBF word space,
     *     5 = BoW + word space, 6 = polynomial BoW + RBF word space; any other value falls back to BoW
     * @param cacheSize size of the kernel and squared-norm caches
     * @return the configured kernel
     */
    private static Kernel buildKernel(int kernelId, int cacheSize) {
        switch (kernelId) {
            case 2:
                // NOTE(review): the first PolynomialKernel argument is presumably the degree; 0 would
                // make the kernel constant. Confirm the intended value before enabling this option.
                return getPolyBow(cacheSize, 0.0f);
            case 3:
                return getWordspaceKernel(cacheSize);
            case 4:
                // NOTE(review): the first RbfKernel argument is presumably gamma; 0 would make the
                // kernel constant. Confirm the intended value before enabling this option.
                return getRbfWordspaceKernel(cacheSize, 0.0f);
            case 5:
                return getBowWordSpaceKernel(cacheSize);
            case 6:
                // NOTE(review): same degree/gamma concern as options 2 and 4.
                return getPolyBowRbfWordspaceKernel(cacheSize, 0.0f, 0.0f);
            case 1:
            default:
                return getBowKernel(cacheSize);
        }
    }

    /**
     * Trains a one-vs-all C-SVM on {@code trainingSet}, serializes it, evaluates it on
     * {@code testSet} and prints a per-label precision/recall/F1 report.
     *
     * @param trainingSet examples used for learning; the label set is taken from here
     * @param kernel kernel used by the binary SVMs
     * @param c value used for both the positive and the negative SVM cost
     * @param testSet examples to classify and score
     * @param printErrors if true, writes one "gold, predicted, 1/0" line per test example to the errors file
     * @return the evaluator's mean F1 over the test set
     * @throws IOException if serialization or writing the errors file fails
     * @throws IllegalStateException if the neutral, positive or negative label is missing
     */
    private static float test(SimpleDataset trainingSet, Kernel kernel, float c, SimpleDataset testSet,
            boolean printErrors) throws NoSuchPerformanceMeasureException, IOException {
        List<Label> labels = trainingSet.getClassificationLabels();
        OneVsAllLearning learner = buildLearner(kernel, c, labels);
        learner.learn(trainingSet);
        OneVsAllClassifier classifier = learner.getPredictionFunction();
        saveModel(learner, classifier);

        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator(labels);
        if (printErrors) {
            // try-with-resources closes the log even if prediction or counting throws.
            try (PrintStream errorLog = new PrintStream(ERRORS_FILE, CharEncoding.UTF_8)) {
                classifyAll(classifier, testSet, evaluator, errorLog);
                // PrintStream swallows write failures; surface them instead of leaving a truncated file.
                if (errorLog.checkError()) {
                    throw new IOException("Failed to write " + ERRORS_FILE);
                }
            }
        } else {
            classifyAll(classifier, testSet, evaluator, null);
        }

        printReport(evaluator, labels);
        return evaluator.getMeanF1();
    }

    /** Configures a one-vs-all learner whose binary base learner is a fair C-SVM with Cp = Cn = c. */
    private static OneVsAllLearning buildLearner(Kernel kernel, float c, List<Label> labels) {
        BinaryCSvmClassification svm = new BinaryCSvmClassification();
        svm.setKernel(kernel);
        svm.setCp(c);
        svm.setCn(c);
        svm.setFairness(true);

        OneVsAllLearning learner = new OneVsAllLearning();
        learner.setBaseAlgorithm(svm);
        learner.setLabels(labels);
        return learner;
    }

    /**
     * Serializes the learner specification and the trained classifier. This runs on every
     * {@code test} call, including tuning runs, so the files end up holding the last model trained.
     */
    private static void saveModel(OneVsAllLearning learner, OneVsAllClassifier classifier) throws IOException {
        JacksonSerializerWrapper serializer = new JacksonSerializerWrapper();
        serializer.writeValueOnFile(learner, LEARNING_ALGORITHM_FILE);
        serializer.writeValueOnFile(classifier, CLASSIFIER_FILE);
    }

    /**
     * Classifies every example of {@code testSet} and adds the outcome to {@code evaluator}.
     *
     * @param errorLog if non-null, receives one "gold, predicted, 1/0" line per example; the gold
     *     label is the example's first label and the prediction is the first predicted class
     */
    private static void classifyAll(OneVsAllClassifier classifier, SimpleDataset testSet,
            MulticlassClassificationEvaluator evaluator, PrintStream errorLog) {
        for (Example example : testSet.getExamples()) {
            OneVsAllClassificationOutput prediction = classifier.predict(example);
            if (errorLog != null) {
                Label gold = example.getLabels()[0];
                Label predicted = prediction.getPredictedClasses().get(0);
                errorLog.println(gold + FIELD_SEP + predicted + FIELD_SEP + (gold.equals(predicted) ? "1" : "0"));
            }
            evaluator.addCount(example, prediction);
        }
    }

    /**
     * Prints a tab-separated report with precision/recall/F1 for the positive, negative and
     * neutral labels, followed by "F1-Pn" (the MeanF1For measure over positive and negative) and
     * "F1-Pnn" (the evaluator's MeanF1 measure).
     */
    private static void printReport(MulticlassClassificationEvaluator evaluator, List<Label> labels)
            throws NoSuchPerformanceMeasureException {
        Label neutral = requireLabel("neutral", labels);
        Label positive = requireLabel("positive", labels);
        Label negative = requireLabel("negative", labels);

        // NOTE(review): kept as java.util.ArrayList because getPerformanceMeasure may resolve the
        // measure method from the argument's runtime type; confirm before changing the list type.
        ArrayList<Label> polarLabels = new ArrayList<>();
        polarLabels.add(positive);
        polarLabels.add(negative);
        List<Label> reportedLabels = new ArrayList<>(polarLabels);
        reportedLabels.add(neutral);

        StringBuilder report = new StringBuilder();
        for (Label label : reportedLabels) {
            report.append(FIELD_SEP).append(label).append(FIELD_SEP);
        }
        report.append('\n');

        String scoreHeader = "Precision" + FIELD_SEP + "Recall" + FIELD_SEP + "F1" + FIELD_SEP;
        for (int i = 0; i < reportedLabels.size(); i++) {
            report.append(scoreHeader);
        }
        report.append("F1-Pn").append(FIELD_SEP).append("F1-Pnn\n");

        for (Label label : reportedLabels) {
            report.append(evaluator.getPrecisionFor(label)).append(FIELD_SEP)
                    .append(evaluator.getRecallFor(label)).append(FIELD_SEP)
                    .append(evaluator.getF1For(label)).append(FIELD_SEP);
        }
        report.append(evaluator.getPerformanceMeasure("MeanF1For", polarLabels)).append(FIELD_SEP);
        report.append(evaluator.getPerformanceMeasure("MeanF1"));
        System.out.println(report.toString());
    }

    /** Like {@link #findLabel} but fails with a descriptive message instead of returning null. */
    private static Label requireLabel(String name, List<Label> labels) {
        Label label = findLabel(name, labels);
        if (label == null) {
            throw new IllegalStateException("Label '" + name + "' not found among " + labels);
        }
        return label;
    }

    /**
     * Returns the first label whose string form equals {@code str} ignoring case.
     *
     * @return the matching label, or {@code null} if none matches
     */
    private static Label findLabel(String str, List<Label> list) {
        for (Label label : list) {
            if (label.toString().equalsIgnoreCase(str)) {
                return label;
            }
        }
        return null;
    }

    /**
     * Selects C by training on one part of a class-distribution-invariant split of
     * {@code trainingSet} and measuring mean F1 on the other part.
     *
     * @param split fraction passed to {@code splitClassDistributionInvariant}
     * @param cValues candidate C values, tried in order; ties keep the earlier candidate
     * @return the candidate with the highest validation mean F1; if no F1 compares greater
     *     (e.g. all NaN), the first candidate
     * @throws IllegalArgumentException if {@code cValues} is null or empty
     */
    private static float tune(SimpleDataset trainingSet, Kernel kernel, float split, float[] cValues)
            throws NoSuchPerformanceMeasureException, IOException {
        if (cValues == null || cValues.length == 0) {
            throw new IllegalArgumentException("At least one candidate C value is required");
        }
        SimpleDataset[] parts = trainingSet.splitClassDistributionInvariant(split);
        SimpleDataset tuningTrainSet = parts[0];
        SimpleDataset validationSet = parts[1];

        // Start from a real candidate so a usable C (never 0) is returned in every case.
        float bestC = cValues[0];
        float bestF1 = Float.NEGATIVE_INFINITY;
        for (float c : cValues) {
            float f1 = test(tuningTrainSet, kernel, c, validationSet, false);
            System.out.println("C:" + c + FIELD_SEP + f1);
            if (f1 > bestF1) {
                bestF1 = f1;
                bestC = c;
            }
        }
        return bestC;
    }

    /** Linear kernel on representation {@code representation}, with a squared-norm cache, wrapped in a NormalizationKernel. */
    private static NormalizationKernel normalizedLinear(String representation, int cacheSize) {
        LinearKernel linear = new LinearKernel(representation);
        linear.setSquaredNormCache(new FixIndexSquaredNormCache(cacheSize));
        return new NormalizationKernel(linear);
    }

    /** Normalized linear kernel on the "BOW" representation, with a kernel cache. */
    private static Kernel getBowKernel(int cacheSize) {
        NormalizationKernel kernel = normalizedLinear("BOW", cacheSize);
        kernel.setKernelCache(new FixSizeKernelCache(cacheSize));
        return kernel;
    }

    /** Normalized polynomial kernel over a linear kernel on "BOW"; {@code f} is the first PolynomialKernel argument. */
    private static Kernel getPolyBow(int cacheSize, float f) {
        PolynomialKernel poly = new PolynomialKernel(f, new LinearKernel("BOW"));
        poly.setSquaredNormCache(new FixIndexSquaredNormCache(cacheSize));
        NormalizationKernel kernel = new NormalizationKernel(poly);
        kernel.setKernelCache(new FixSizeKernelCache(cacheSize));
        return kernel;
    }

    /** Normalized linear kernel on the "WS" representation, with a kernel cache. */
    private static Kernel getWordspaceKernel(int cacheSize) {
        NormalizationKernel kernel = normalizedLinear("WS", cacheSize);
        kernel.setKernelCache(new FixSizeKernelCache(cacheSize));
        return kernel;
    }

    /** Normalized RBF kernel over a linear kernel on "WS"; {@code f} is the first RbfKernel argument. */
    private static Kernel getRbfWordspaceKernel(int cacheSize, float f) {
        RbfKernel rbf = new RbfKernel(f, new LinearKernel("WS"));
        rbf.setSquaredNormCache(new FixIndexSquaredNormCache(cacheSize));
        NormalizationKernel kernel = new NormalizationKernel(rbf);
        kernel.setKernelCache(new FixSizeKernelCache(cacheSize));
        return kernel;
    }

    /** Equally weighted sum of the normalized "BOW" and "WS" linear kernels; only the sum is cached. */
    private static Kernel getBowWordSpaceKernel(int cacheSize) {
        LinearKernelCombination combination = new LinearKernelCombination();
        combination.addKernel(1.0f, normalizedLinear("BOW", cacheSize));
        combination.addKernel(1.0f, normalizedLinear("WS", cacheSize));
        combination.setKernelCache(new FixSizeKernelCache(cacheSize));
        return combination;
    }

    /**
     * Equally weighted sum of a normalized polynomial kernel on "BOW" (first argument {@code f})
     * and a normalized RBF kernel on "WS" (first argument {@code f2}); only the sum is cached.
     */
    private static Kernel getPolyBowRbfWordspaceKernel(int cacheSize, float f, float f2) {
        PolynomialKernel poly = new PolynomialKernel(f, new LinearKernel("BOW"));
        poly.setSquaredNormCache(new FixIndexSquaredNormCache(cacheSize));
        NormalizationKernel normalizedPoly = new NormalizationKernel(poly);

        RbfKernel rbf = new RbfKernel(f2, new LinearKernel("WS"));
        rbf.setSquaredNormCache(new FixIndexSquaredNormCache(cacheSize));
        NormalizationKernel normalizedRbf = new NormalizationKernel(rbf);

        LinearKernelCombination combination = new LinearKernelCombination();
        combination.addKernel(1.0f, normalizedPoly);
        combination.addKernel(1.0f, normalizedRbf);
        combination.setKernelCache(new FixSizeKernelCache(cacheSize));
        return combination;
    }
}
