package it.uniroma2.sag.kelp.utils.evaluation;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:it/uniroma2/sag/kelp/utils/evaluation/MulticlassClassificationEvaluator.class */
/**
 * Evaluator for multi-class (and multi-label) classification.
 *
 * <p>For every label supplied at construction time it accumulates true/false
 * positive/negative counts via {@link #addCount(Example, Prediction)} and, on
 * demand, derives accuracy, per-label precision/recall/F1 and their micro and
 * macro aggregates. Derived values are recomputed the first time a getter is
 * called after the counters changed (tracked through the inherited
 * {@code computed} flag).
 */
public class MulticlassClassificationEvaluator extends Evaluator {
    /** Labels under evaluation; the same list instance passed to the constructor. */
    protected List<Label> labels;
    /** Per-label counters and derived metrics. */
    protected HashMap<Label, ClassStats> classStats = new HashMap<>();
    /** Number of examples seen by {@link #addCount(Example, Prediction)}. */
    protected int total;
    /** Number of examples for which every label was decided correctly. */
    protected int correct;
    /** Sum over all labels of the respective per-label counters. */
    protected int totalTp;
    protected int totalTn;
    protected int totalFp;
    protected int totalFn;
    private float accuracy;
    // NOTE(review): the three "overall" fields are never assigned anywhere in this
    // class, so their (deprecated) getters always return 0.0f — confirm intended value.
    private float overallPrecision;
    private float overallRecall;
    private float overallF1;
    private float microPrecision;
    private float microRecall;
    private float microF1;
    private float macroPrecision;
    private float macroRecall;
    private float macroF1;

    /** Raw counters and derived metrics for a single label. */
    public class ClassStats {
        protected int tp;
        protected int fp;
        protected int tn;
        protected int fn;
        protected float precision;
        protected float recall;
        protected float f1;

        protected ClassStats() {
        }
    }

    /**
     * Creates an evaluator for the given labels.
     *
     * @param list the labels to evaluate; the list is stored by reference, not copied
     */
    public MulticlassClassificationEvaluator(List<Label> list) {
        this.labels = list;
        initializeCounters();
    }

    /** Registers a fresh {@link ClassStats} for every label and marks metrics as stale. */
    private void initializeCounters() {
        for (Label label : this.labels) {
            this.classStats.put(label, new ClassStats());
        }
        this.accuracy = 0.0f;
        this.computed = false;
    }

    /**
     * Updates the counters with one example and its prediction.
     *
     * <p>Each label is scored independently; the example counts as correct only
     * if no label produced a false positive or a false negative.
     *
     * @param example    the gold-standard example
     * @param prediction the prediction for {@code example}; must be a
     *                   {@link ClassificationOutput} (it is cast unconditionally)
     */
    @Override
    public void addCount(Example example, Prediction prediction) {
        ClassificationOutput output = (ClassificationOutput) prediction;
        boolean allLabelsCorrect = true;
        for (Label label : this.labels) {
            ClassStats stats = this.classStats.get(label);
            boolean gold = example.isExampleOf(label);
            boolean predicted = output.isClassPredicted(label);
            if (gold && predicted) {
                stats.tp++;
                this.totalTp++;
            } else if (gold) {
                stats.fn++;
                this.totalFn++;
                allLabelsCorrect = false;
            } else if (predicted) {
                stats.fp++;
                this.totalFp++;
                allLabelsCorrect = false;
            } else {
                stats.tn++;
                this.totalTn++;
            }
        }
        this.total++;
        if (allLabelsCorrect) {
            this.correct++;
        }
        this.computed = false;
    }

    /**
     * Recomputes every derived metric from the current counters.
     *
     * <p>All ratios are evaluated in floating point. Conventions for empty
     * denominators: per-label precision is 1 when the label was never
     * predicted, per-label recall is 0 when the label never occurred, micro
     * precision/recall are 0, and any F1 is 0 when either of its inputs is 0.
     * Macro F1 is the harmonic mean of macro precision and macro recall.
     */
    @Override
    protected void compute() {
        // Assign unconditionally so a value from before clear() cannot survive.
        this.accuracy = this.total > 0 ? (float) this.correct / this.total : 0.0f;

        // Both accumulators must restart from zero on every recomputation.
        this.macroPrecision = 0.0f;
        this.macroRecall = 0.0f;
        for (Label label : this.labels) {
            ClassStats stats = this.classStats.get(label);

            int predictedPositives = stats.tp + stats.fp;
            stats.precision = predictedPositives == 0 ? 1.0f : (float) stats.tp / predictedPositives;
            this.macroPrecision += stats.precision;

            int actualPositives = stats.tp + stats.fn;
            stats.recall = actualPositives == 0 ? 0.0f : (float) stats.tp / actualPositives;
            this.macroRecall += stats.recall;

            stats.f1 = harmonicMean(stats.precision, stats.recall);
        }
        this.macroPrecision /= this.labels.size();
        this.macroRecall /= this.labels.size();
        this.macroF1 = harmonicMean(this.macroPrecision, this.macroRecall);

        int microPredicted = this.totalTp + this.totalFp;
        this.microPrecision = microPredicted > 0 ? (float) this.totalTp / microPredicted : 0.0f;
        int microActual = this.totalTp + this.totalFn;
        this.microRecall = microActual > 0 ? (float) this.totalTp / microActual : 0.0f;
        this.microF1 = harmonicMean(this.microPrecision, this.microRecall);

        this.computed = true;
    }

    /** Returns 2pr/(p+r), or 0 when either argument is exactly 0. */
    private static float harmonicMean(float precision, float recall) {
        if (precision == 0.0f || recall == 0.0f) {
            return 0.0f;
        }
        return ((2.0f * precision) * recall) / (precision + recall);
    }

    /** Recomputes the derived metrics if the counters changed since the last computation. */
    private void ensureComputed() {
        if (!this.computed) {
            compute();
        }
    }

    /**
     * @param label the label of interest
     * @return the precision for {@code label}, or -1 if the label is not tracked
     */
    public float getPrecisionFor(Label label) {
        ensureComputed();
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.precision : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the recall for {@code label}, or -1 if the label is not tracked
     */
    public float getRecallFor(Label label) {
        ensureComputed();
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.recall : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the F1 for {@code label}, or -1 if the label is not tracked
     */
    public float getF1For(Label label) {
        ensureComputed();
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.f1 : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the true-positive count for {@code label}, or -1 if the label is not tracked
     */
    public float getTpFor(Label label) {
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.tp : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the true-negative count for {@code label}, or -1 if the label is not tracked
     */
    public float getTnFor(Label label) {
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.tn : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the false-positive count for {@code label}, or -1 if the label is not tracked
     */
    public float getFpFor(Label label) {
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.fp : -1.0f;
    }

    /**
     * @param label the label of interest
     * @return the false-negative count for {@code label}, or -1 if the label is not tracked
     */
    public float getFnFor(Label label) {
        ClassStats stats = this.classStats.get(label);
        return stats != null ? stats.fn : -1.0f;
    }

    /** @return the fraction of examples with every label decided correctly (0 if none seen) */
    public float getAccuracy() {
        ensureComputed();
        return this.accuracy;
    }

    /** @return precision computed from the counters summed over all labels */
    public float getMicroPrecision() {
        ensureComputed();
        return this.microPrecision;
    }

    /** @return recall computed from the counters summed over all labels */
    public float getMicroRecall() {
        ensureComputed();
        return this.microRecall;
    }

    /** @return harmonic mean of micro precision and micro recall */
    public float getMicroF1() {
        ensureComputed();
        return this.microF1;
    }

    /** @return the {@code overallPrecision} field, which this class never assigns */
    @Deprecated
    public float getOverallPrecision() {
        ensureComputed();
        return this.overallPrecision;
    }

    /** @return the {@code overallRecall} field, which this class never assigns */
    @Deprecated
    public float getOverallRecall() {
        ensureComputed();
        return this.overallRecall;
    }

    /** @return the {@code overallF1} field, which this class never assigns */
    @Deprecated
    public float getOverallF1() {
        ensureComputed();
        return this.overallF1;
    }

    /** @return the same value as {@link #getMacroF1()} */
    @Deprecated
    public float getMeanF1() {
        ensureComputed();
        return this.macroF1;
    }

    /** @return the unweighted mean of the per-label precisions */
    public float getMacroPrecision() {
        ensureComputed();
        return this.macroPrecision;
    }

    /** @return the unweighted mean of the per-label recalls */
    public float getMacroRecall() {
        ensureComputed();
        return this.macroRecall;
    }

    /** @return harmonic mean of macro precision and macro recall */
    public float getMacroF1() {
        ensureComputed();
        return this.macroF1;
    }

    /**
     * Averages the per-label F1 over the given labels.
     *
     * <p>Labels that are not tracked contribute 0 to the sum but still count in
     * the denominator. An empty list yields NaN (0/0 in floating point).
     *
     * @param arrayList the labels to average over
     * @return the sum of the F1 of the tracked labels divided by {@code arrayList.size()}
     */
    public float getMeanF1For(ArrayList<Label> arrayList) {
        ensureComputed();
        float f1Sum = 0.0f;
        for (Label label : arrayList) {
            ClassStats stats = this.classStats.get(label);
            if (stats != null) {
                f1Sum += stats.f1;
            }
        }
        return f1Sum / arrayList.size();
    }

    /** Resets every counter, per-label and aggregated, and marks the metrics as stale. */
    @Override
    public void clear() {
        for (ClassStats stats : this.classStats.values()) {
            stats.tp = 0;
            stats.tn = 0;
            stats.fp = 0;
            stats.fn = 0;
        }
        this.total = 0;
        this.correct = 0;
        // The aggregated counters feed the micro metrics and must be reset as well.
        this.totalTp = 0;
        this.totalTn = 0;
        this.totalFp = 0;
        this.totalFn = 0;
        this.computed = false;
    }

    /** Prints the counters of every label to standard output. */
    private void printCounters() {
        for (Label label : this.labels) {
            System.out.println(label);
            System.out.print("\t");
            printCounters(label);
        }
    }

    /**
     * Prints the tp/tn/fp/fn counters of one label to standard output, or a
     * notice if the label is not tracked.
     *
     * @param label the label whose counters are printed
     */
    public void printCounters(Label label) {
        ClassStats stats = this.classStats.get(label);
        if (stats != null) {
            System.out.println("class " + label.toString() + ": tp=" + stats.tp + " tn=" + stats.tn + " fp=" + stats.fp + " fn=" + stats.fn);
        } else {
            System.out.println("There are no counters for the label " + label.toString());
        }
    }

    /**
     * @return one line per label, tab-separated: label, precision, recall, F1
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Label label : this.labels) {
            sb.append(label).append("\t")
                    .append(getPrecisionFor(label)).append("\t")
                    .append(getRecallFor(label)).append("\t")
                    .append(getF1For(label)).append("\n");
        }
        return sb.toString().trim();
    }

    /**
     * @return a new evaluator over the same label list with all counters at zero
     */
    @Override
    public MulticlassClassificationEvaluator duplicate() {
        return new MulticlassClassificationEvaluator(this.labels);
    }
}
