package it.uniroma2.sag.kelp.examples.demo.qc;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive.LinearPassiveAggressiveClassification;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import org.slf4j.impl.SimpleLogger;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/demo/qc/QuestionClassificationIncrementalLearning.class */
/**
 * Demo of incremental (two-stage) learning on a question classification dataset.
 *
 * <p>The program trains a one-vs-all learner, built on a linear passive-aggressive
 * base algorithm over the "bow" representation, on the first part of a split
 * training set and then continues training on the second part. Accuracy on the
 * test set is printed after each stage. The same second stage is then repeated
 * starting from a learner and a model that were restored from their JSON
 * serializations, and the corresponding accuracies are printed as well.
 */
public class QuestionClassificationIncrementalLearning {

    /**
     * Runs the demo. Any exception raised along the way is caught and its stack
     * trace is printed.
     *
     * @param strArr command-line arguments (not used)
     */
    public static void main(String[] strArr) {
        try {
            System.setProperty(SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "INFO");

            // Load the training and the test data.
            SimpleDataset trainingSet = new SimpleDataset();
            trainingSet.populate("src/main/resources/qc/train_5500.coarse.klp.gz");
            SimpleDataset testSet = new SimpleDataset();
            testSet.populate("src/main/resources/qc/TREC_10.coarse.klp.gz");

            System.out.println("Training set statistics");
            System.out.print("Examples number ");
            System.out.println(trainingSet.getNumberOfExamples());

            JacksonSerializerWrapper serializer = new JacksonSerializerWrapper();

            // Binary base learner wrapped into a one-vs-all multiclass scheme.
            LinearPassiveAggressiveClassification baseLearner = new LinearPassiveAggressiveClassification();
            baseLearner.setRepresentation("bow");
            baseLearner.setC(3.0f);
            baseLearner.setFairness(true);
            OneVsAllLearning learner = new OneVsAllLearning();
            learner.setBaseAlgorithm(baseLearner);
            learner.setLabels(trainingSet.getClassificationLabels());

            // Serialize the learner before any training; this JSON is used later
            // to rebuild a second learner instance.
            String learnerJson = serializer.writeValueAsString(learner);
            System.out.println(learnerJson);

            // Split the training data in two parts (split argument 0.05) and shuffle each.
            Dataset[] trainingParts = trainingSet.splitClassDistributionInvariant(0.05f);
            Dataset firstPart = trainingParts[0].getShuffledDataset();
            Dataset secondPart = trainingParts[1].getShuffledDataset();

            // Stage 1: learn on the first part and evaluate.
            learner.learn(firstPart);
            OneVsAllClassifier model = learner.getPredictionFunction();
            MulticlassClassificationEvaluator evaluator =
                    new MulticlassClassificationEvaluator(trainingSet.getClassificationLabels());
            countPredictions(model, testSet, evaluator);
            System.out.println("Accuracy after first part training: " + evaluator.getAccuracy());

            // Take a JSON round-trip copy of the stage-1 model before training continues.
            String modelJson = serializer.writeValueAsString(model);
            Classifier restoredModel = (Classifier) serializer.readValue(modelJson, Classifier.class);

            // Stage 2: keep learning on the second part and evaluate again.
            learner.learn(secondPart);
            evaluator.clear();
            countPredictions(model, testSet, evaluator);
            System.out.println("Accuracy after second part training: " + evaluator.getAccuracy());

            // Rebuild the learner from JSON and attach the restored stage-1 model to it.
            OneVsAllLearning restoredLearner =
                    (OneVsAllLearning) serializer.readValue(learnerJson, OneVsAllLearning.class);
            restoredLearner.setPredictionFunction(restoredModel);
            MulticlassClassificationEvaluator restoredEvaluator =
                    new MulticlassClassificationEvaluator(trainingSet.getClassificationLabels());
            countPredictions(restoredModel, testSet, restoredEvaluator);
            System.out.println("Accuracy model from JSON after first part training: " + restoredEvaluator.getAccuracy());

            // Repeat stage 2 with the restored learner and evaluate the restored model.
            restoredLearner.learn(secondPart);
            restoredEvaluator.clear();
            countPredictions(restoredModel, testSet, restoredEvaluator);
            System.out.println("Accuracy model from JSON after second part training: " + restoredEvaluator.getAccuracy());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Classifies every example of {@code data} with {@code model} and registers
     * each prediction in {@code evaluator}.
     *
     * @param model     classifier used to produce the predictions
     * @param data      dataset whose examples are classified
     * @param evaluator evaluator that accumulates the prediction counts
     */
    private static void countPredictions(Classifier model, SimpleDataset data,
            MulticlassClassificationEvaluator evaluator) {
        for (Example example : data.getExamples()) {
            evaluator.addCount(example, model.predict(example));
        }
    }
}
