package it.uniroma2.sag.kelp.examples.demo.rcv1;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.StringLabel;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.Math;
import it.uniroma2.sag.kelp.utils.evaluation.BinaryClassificationEvaluator;
import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/demo/rcv1/RCV1BinaryTextCategorization.class */
/**
 * Cross-validation driver for binary text categorization on RCV1.
 *
 * <p>Subclasses supply the concrete {@link LearningAlgorithm} through
 * {@link #getLearningAlgorithm(float, String, StringLabel)}; this class splits the data
 * into folds, trains/evaluates on each, and prints the accuracy mean and standard deviation.
 */
public abstract class RCV1BinaryTextCategorization {
    /** Label used as the positive class by both the learner and the evaluator. */
    private final StringLabel positiveLabel = new StringLabel("1");

    /**
     * Suffix appended to the names of the specification files serialized during each fold,
     * so that runs of different algorithms can write to distinct files.
     */
    protected String algoSuffix = "";

    /**
     * Runs an n-fold cross validation and prints the accuracy mean/standard deviation.
     *
     * <p>Each fold is used once as the test set while the remaining folds form the training
     * set. Folds whose training/evaluation throws are reported and excluded from the
     * statistics instead of being counted as 0% accuracy.
     *
     * @param f the parameter passed as the first argument of {@code getLearningAlgorithm}
     *     (logged as "C")
     * @param i the number of folds requested from
     *     {@code nFoldingClassDistributionInvariant}
     * @param simpleDataset the dataset to split into folds
     */
    public void foldLearn(float f, int i, SimpleDataset simpleDataset) {
        SimpleDataset[] folds = simpleDataset.nFoldingClassDistributionInvariant(i);
        float[] accuracies = new float[folds.length];
        int completed = 0;
        // Iterate over the folds actually produced, so the bound always matches the array size.
        for (int fold = 0; fold < folds.length; fold++) {
            SimpleDataset testSet = folds[fold];
            SimpleDataset trainingSet = getAllExcept(folds, fold);
            try {
                System.out.println("start testing with C=" + f);
                accuracies[completed] = test(trainingSet, f, testSet);
                completed++;
            } catch (NoSuchPerformanceMeasureException | IOException e) {
                // Keep going with the other folds, but do not let this one contribute a
                // fabricated 0 accuracy to the mean/std below.
                System.err.println("Fold " + fold + " failed and is excluded from the statistics: " + e);
                e.printStackTrace();
            }
        }
        if (completed == 0) {
            System.err.println("No fold completed successfully; accuracy statistics are unavailable");
            return;
        }
        float[] measured = Arrays.copyOf(accuracies, completed);
        System.out.println("Accuracy mean/std on test set=" + Math.getMean(measured) + "/" + Math.getStandardDeviation(measured));
    }

    /**
     * Trains a classifier on one dataset, serializes the learner and the resulting classifier,
     * and returns the accuracy on another dataset.
     *
     * @param simpleDataset the training set
     * @param f the parameter forwarded to {@code getLearningAlgorithm}
     * @param simpleDataset2 the test set
     * @return the accuracy measured by a {@link BinaryClassificationEvaluator}
     * @throws NoSuchPerformanceMeasureException if the evaluator cannot provide the measure
     * @throws IOException if writing the serialized specifications fails
     */
    private float test(SimpleDataset simpleDataset, float f, SimpleDataset simpleDataset2) throws NoSuchPerformanceMeasureException, IOException {
        // NOTE(review): "VEC" is presumably the representation name used by the RCV1 data; confirm.
        LearningAlgorithm learner = getLearningAlgorithm(f, "VEC", this.positiveLabel);
        learner.learn(simpleDataset);
        // Assumes subclasses return algorithms whose prediction function is linear.
        BinaryLinearClassifier classifier = (BinaryLinearClassifier) learner.getPredictionFunction();

        // These files are overwritten on every fold; only the last fold's model remains on disk.
        JacksonSerializerWrapper serializer = new JacksonSerializerWrapper();
        serializer.writeValueOnFile(learner, "src/main/resources/rcv1/learningAlgorithmSpecification" + this.algoSuffix + ".klp");
        serializer.writeValueOnFile(classifier, "src/main/resources/rcv1/classificationAlgorithmSpecification" + this.algoSuffix + ".klp");

        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator(this.positiveLabel);
        for (Example example : simpleDataset2.getExamples()) {
            evaluator.addCount(example, classifier.predict(example));
        }
        return evaluator.getAccuracy();
    }

    /**
     * Creates the learning algorithm to evaluate.
     *
     * @param f the parameter supplied to {@code foldLearn} (logged there as "C")
     * @param str the representation identifier the algorithm should operate on
     * @param stringLabel the positive label of the binary task
     * @return a learning algorithm whose prediction function is a {@link BinaryLinearClassifier}
     */
    protected abstract LearningAlgorithm getLearningAlgorithm(float f, String str, StringLabel stringLabel);

    /**
     * Merges every dataset in {@code datasetArr} except the one at index {@code i}.
     *
     * @param datasetArr the folds
     * @param i the index of the fold to leave out
     * @return a new dataset containing the examples of all other folds
     */
    private static SimpleDataset getAllExcept(Dataset[] datasetArr, int i) {
        SimpleDataset merged = new SimpleDataset();
        for (int idx = 0; idx < datasetArr.length; idx++) {
            if (idx != i) {
                merged.addExamples(datasetArr[idx]);
            }
        }
        return merged;
    }
}
