package it.uniroma2.sag.kelp.utils.evaluation;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:it/uniroma2/sag/kelp/utils/evaluation/RegressorEvaluator.class */
/**
 * Evaluator that accumulates, for each label it was constructed with, the
 * sum of squared differences between the predicted score and the example's
 * regression value, and reports the mean of those squared differences.
 *
 * <p>The set of tracked labels is fixed at construction time; {@link #clear()}
 * resets the accumulated statistics but keeps the labels.
 */
public class RegressorEvaluator extends Evaluator {
    /** Per-label running sum of squared (score - regression value) differences. */
    private final HashMap<Label, Float> values = new HashMap<>();
    /** Per-label mean squared error, refreshed by {@link #compute()}. */
    private final HashMap<Label, Float> errors = new HashMap<>();
    /** Number of (example, prediction) pairs added since construction or the last clear. */
    int n = 0;

    /**
     * Creates an evaluator tracking the given labels, with every accumulator
     * starting at zero.
     *
     * @param list the labels to track
     */
    public RegressorEvaluator(List<Label> list) {
        for (Label label : list) {
            this.values.put(label, 0.0f);
            this.errors.put(label, 0.0f);
        }
    }

    /**
     * Adds the squared difference between the predicted score and the
     * example's regression value to the running sum of every tracked label,
     * and invalidates the previously computed errors.
     *
     * @param example the example providing the regression values
     * @param prediction the prediction; it is cast to
     *        {@link UnivariateRegressionOutput} to read the per-label score
     */
    @Override
    public void addCount(Example example, Prediction prediction) {
        // Re-putting an existing key is not a structural modification, so
        // updating the map while iterating over its key set is safe.
        for (Label label : this.values.keySet()) {
            float expected = example.getRegressionValue(label);
            float predicted = ((UnivariateRegressionOutput) prediction).getScore(label);
            float difference = predicted - expected;
            this.values.put(label, this.values.get(label) + difference * difference);
        }
        this.n++;
        this.computed = false;
    }

    /**
     * Recomputes the per-label mean squared error as the accumulated sum
     * divided by the number of added examples. When no example has been added
     * the float division 0/0 yields {@code NaN}.
     */
    @Override
    protected void compute() {
        for (Label label : this.values.keySet()) {
            this.errors.put(label, this.values.get(label) / this.n);
        }
        this.computed = true;
    }

    /**
     * Returns the mean squared error of a single label, computing the errors
     * first if they are out of date.
     *
     * @param label the label of interest
     * @return the mean squared error of {@code label}, or {@code -1.0f} when
     *         the label is not tracked by this evaluator
     */
    public float getMeanSquaredError(Label label) {
        if (!this.computed) {
            compute();
        }
        Float error = this.errors.get(label);
        if (error == null) {
            return -1.0f;
        }
        return error;
    }

    /**
     * Returns the mean squared error averaged over all tracked labels,
     * computing the errors first if they are out of date.
     *
     * @return the arithmetic mean of the per-label mean squared errors;
     *         {@code NaN} when no label is tracked (0/0 in float arithmetic)
     */
    public float getMeanSquaredErrors() {
        if (!this.computed) {
            compute();
        }
        float sum = 0.0f;
        for (Float error : this.errors.values()) {
            sum += error;
        }
        return sum / this.errors.size();
    }

    /**
     * Resets the accumulated statistics to the state of a freshly constructed
     * evaluator. The tracked labels are kept: dropping them would leave the
     * evaluator unable to accumulate anything afterwards, and would make
     * {@link #duplicate()} return an evaluator without labels.
     */
    @Override
    public void clear() {
        // Overwrite the values in place instead of clearing the maps, so the
        // label keys survive the reset.
        for (Label label : this.values.keySet()) {
            this.values.put(label, 0.0f);
            this.errors.put(label, 0.0f);
        }
        this.n = 0;
        this.computed = false;
    }

    /**
     * Creates a new evaluator tracking the same labels as this one, with all
     * statistics starting from zero.
     *
     * @return a fresh evaluator over the same labels
     */
    @Override
    public RegressorEvaluator duplicate() {
        List<Label> labels = new ArrayList<Label>(this.values.keySet());
        return new RegressorEvaluator(labels);
    }
}
