package it.uniroma2.sag.kelp.linearization.nystrom;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SimpleExample;
import it.uniroma2.sag.kelp.data.representation.vector.DenseVector;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.linearization.LinearizationFunction;
import it.uniroma2.sag.kelp.utils.FileUtils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.factory.SingularValueDecomposition;
import org.ejml.ops.CommonOps;
import org.ejml.ops.SingularOps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link LinearizationFunction} that embeds examples into an explicit vector space using the
 * Nystr&ouml;m method.
 *
 * <p>Given {@code m} landmark examples and a {@link Kernel}, the landmark Gram matrix {@code W} is
 * decomposed by SVD as {@code W = U * Sigma * V^T}. Singular values are sorted in descending order
 * and those whose ratio to the largest one is below {@link #ESPILON} are dropped, which may lower
 * the requested rank. An example {@code x} is then mapped to {@code c_x * U * Sigma_k^(-1/2)},
 * where {@code c_x} is the row vector of kernel values between {@code x} and each landmark, and
 * {@code Sigma_k^(-1/2)} inverts the square root of the first {@code rank} singular values and
 * sets all others to 0.
 *
 * <p>The projection is exposed as the bean property {@link #getProjectionMatrix()} (a flat,
 * row-major list of {@code m * rank} values) so that it is written by {@link #save(String)} and
 * restored by {@link #load(String)}. The full {@code m x m} operator is rebuilt lazily from that
 * list on first use. Output vectors always have {@code m} components; components at positions
 * {@code >= rank} are 0.
 *
 * <p>No synchronization is performed around the lazily built operator cache.
 */
@JsonTypeName("nystrom")
public class NystromMethod implements LinearizationFunction {
    private static final Logger logger = LoggerFactory.getLogger(NystromMethod.class);

    /**
     * Relative threshold: a singular value whose ratio to the largest singular value is below
     * this value is treated as zero. (The misspelled name is kept for compatibility.)
     */
    @JsonIgnore
    public static final float ESPILON = 1.0E-5f;

    /** Number of landmarks between two progress messages while building W. */
    private static final int LANDMARK_LOG_INTERVAL = 100;

    /** Number of examples between two progress messages in {@link #getLinearizedDataset}. */
    private static final int PROJECTION_LOG_INTERVAL = 1000;

    private Kernel kernel;
    private List<Example> landmarks;
    private List<Double> projectionMatrix;

    /**
     * Lazily rebuilt {@code m x m} operator {@code U * Sigma_k^(-1/2)}; columns {@code >= rank}
     * are zero. {@code null} until first needed or after a setter invalidates it.
     */
    @JsonIgnore
    private DenseMatrix64F projectionOperator;

    private int rank;

    /**
     * Reads a {@code NystromMethod} previously written by {@link #save(String)}.
     *
     * @param str path of the JSON file, opened through {@code FileUtils.createInputStream}
     * @return the deserialized instance
     * @throws FileNotFoundException declared for compatibility (presumably raised by
     *     {@code FileUtils} when the file is missing — not verified here)
     * @throws IOException if the file cannot be read or the JSON cannot be mapped
     */
    public static NystromMethod load(String str) throws FileNotFoundException, IOException {
        ObjectMapper objectMapper = new ObjectMapper();
        // try-with-resources: the stream is closed even when Jackson throws.
        try (InputStreamReader reader =
                new InputStreamReader(FileUtils.createInputStream(str), StandardCharsets.UTF_8)) {
            return objectMapper.readValue(reader, NystromMethod.class);
        }
    }

    /** No-arg constructor used by Jackson during {@link #load(String)}. */
    public NystromMethod() {
    }

    /**
     * Builds the projection using every landmark (requested rank = number of landmarks).
     *
     * @param list landmark examples
     * @param kernel kernel used both for the Gram matrix and for projecting examples
     * @throws InstantiationException declared for compatibility; not thrown by this code
     */
    public NystromMethod(List<Example> list, Kernel kernel) throws InstantiationException {
        this(list, kernel, list.size());
    }

    /**
     * Builds the projection with the requested rank.
     *
     * @param list landmark examples
     * @param kernel kernel used both for the Gram matrix and for projecting examples
     * @param i requested rank; clamped to the number of landmarks, and possibly lowered further
     *     when trailing singular values are negligible (see {@link #getRank()})
     * @throws IllegalArgumentException if {@code i} is negative
     * @throws RuntimeException if the SVD of the Gram matrix fails
     * @throws InstantiationException declared for compatibility; not thrown by this code
     */
    public NystromMethod(List<Example> list, Kernel kernel, int i) throws InstantiationException {
        if (i < 0) {
            throw new IllegalArgumentException("Rank must be non-negative: " + i);
        }
        this.kernel = kernel;
        this.landmarks = list;
        int size = list.size();
        this.rank = i;
        if (i > size) {
            logger.debug(
                    "Expected Rank ({}) and it is higher than m ({}). It will be reduced to m.",
                    i, size);
            this.rank = size;
        }
        calculateProjMatrix();
    }

    /**
     * Computes {@link #projectionMatrix} and the effective {@link #rank} from the current
     * landmarks and kernel. Requires {@code 0 <= rank <= landmarks.size()}.
     *
     * @throws RuntimeException if the SVD of the landmark Gram matrix fails
     */
    private void calculateProjMatrix() {
        int size = landmarks.size();
        logger.info("Number of landmarks:\t{}", size);

        DenseMatrix64F gram = buildGramMatrix(size);

        logger.info("SVD Decomposition...");
        SingularValueDecomposition<DenseMatrix64F> svd =
                DecompositionFactory.svd(size, size, true, true, false);
        if (!svd.decompose(gram)) {
            throw new RuntimeException("Decomposition failed");
        }
        logger.info("Decomposition completed");

        DenseMatrix64F u = svd.getU(null, false);
        DenseMatrix64F v = svd.getV(null, false);
        DenseMatrix64F sigma = svd.getW(null);
        logger.debug("U\n{}", u);
        logger.debug("Sigma\n{}", sigma);
        // Sort singular values (and the matching columns of U and V) from largest to smallest, so
        // that keeping the first `rank` columns keeps the dominant components.
        SingularOps.descendingOrder(u, false, sigma, v, false);
        if (logger.isDebugEnabled()) {
            for (int k = 0; k < size; k++) {
                logger.debug("Sigma\t{}\t{}", k, sigma.get(k, k));
            }
        }

        this.rank = effectiveRank(sigma, this.rank);
        logger.info("Final matrix rank:\t{}", this.rank);

        logger.info("Calculating Projection matrix...");
        // Turn Sigma into Sigma_k^(-1/2): kept, positive singular values s become 1/sqrt(s);
        // discarded ones become 0 so the corresponding operator columns vanish.
        for (int k = 0; k < size; k++) {
            double s = sigma.get(k, k);
            sigma.set(k, k, (k < this.rank && s > 0.0) ? 1.0 / Math.sqrt(s) : 0.0);
        }
        DenseMatrix64F operator = new DenseMatrix64F(size, size);
        CommonOps.mult(u, sigma, operator);
        logger.debug("U*Sigma^(-1/2)\n{}", operator);

        // Persist only the first `rank` columns, row by row.
        List<Double> flat = new ArrayList<>(size * this.rank);
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < this.rank; col++) {
                flat.add(operator.get(row, col));
            }
        }
        this.projectionMatrix = flat;
        // Rebuild the operator from projectionMatrix on first use, exactly as for a loaded
        // instance.
        this.projectionOperator = null;
    }

    /**
     * Builds the symmetric {@code size x size} Gram matrix of the landmarks, evaluating the
     * kernel only on the upper triangle (diagonal included) and mirroring each value.
     */
    private DenseMatrix64F buildGramMatrix(int size) {
        logger.info("Building W...");
        DenseMatrix64F gram = new DenseMatrix64F(size, size);
        for (int row = 0; row < size; row++) {
            if ((row + 1) % LANDMARK_LOG_INTERVAL == 0) {
                logger.info("Evaluated {} landmarks.", row + 1);
            }
            Example rowLandmark = landmarks.get(row);
            for (int col = row; col < size; col++) {
                float value = kernel.innerProduct(rowLandmark, landmarks.get(col));
                gram.set(row, col, value);
                gram.set(col, row, value);
            }
        }
        logger.debug("W\n{}", gram);
        return gram;
    }

    /**
     * Returns how many leading singular values (already sorted in descending order on the
     * diagonal of {@code sigma}) are kept, capped at {@code maxRank}. A value is dropped, together
     * with every following one, when its ratio to the largest singular value is below
     * {@link #ESPILON}. If the largest singular value is not positive (e.g. an all-zero Gram
     * matrix), nothing is kept; this also avoids a 0/0 division.
     */
    private static int effectiveRank(DenseMatrix64F sigma, int maxRank) {
        if (maxRank == 0) {
            return 0;
        }
        double largest = sigma.get(0, 0);
        if (!(largest > 0.0)) {
            return 0;
        }
        for (int k = 0; k < maxRank; k++) {
            if (sigma.get(k, k) / largest < ESPILON) {
                return k;
            }
        }
        return maxRank;
    }

    /**
     * Projects an example and returns the raw components of its embedding.
     *
     * @param example example to project
     * @return array of {@code landmarks.size()} values; entries at positions {@code >= rank} are 0
     * @throws IllegalStateException if the stored projection is inconsistent with
     *     {@code landmarks} and {@code rank}
     */
    public double[] calculateVector(Example example) {
        return project(example).data;
    }

    /**
     * Computes the {@code 1 x m} row vector {@code c_x * operator}, where {@code c_x} holds the
     * kernel values between {@code example} and each landmark.
     */
    private DenseMatrix64F project(Example example) {
        int size = landmarks.size();
        DenseMatrix64F operator = projectionOperator;
        if (operator == null) {
            // Fully populate before storing, so the field never holds a half-built matrix.
            operator = buildProjectionOperator(size);
            projectionOperator = operator;
        }
        // Local scratch vector: no shared mutable state between projections.
        DenseMatrix64F kernelValues = new DenseMatrix64F(1, size);
        for (int k = 0; k < size; k++) {
            kernelValues.set(0, k, kernel.innerProduct(example, landmarks.get(k)));
        }
        DenseMatrix64F projected = new DenseMatrix64F(1, size);
        CommonOps.mult(kernelValues, operator, projected);
        return projected;
    }

    /**
     * Expands the flat row-major {@link #projectionMatrix} ({@code size * rank} values) into a
     * {@code size x size} matrix whose columns {@code >= rank} are zero.
     *
     * @throws IllegalStateException if {@code rank} is out of {@code [0, size]} or the list
     *     length is not {@code size * rank}
     */
    private DenseMatrix64F buildProjectionOperator(int size) {
        if (rank < 0 || rank > size) {
            throw new IllegalStateException(
                    "Rank " + rank + " is out of range for " + size + " landmarks");
        }
        int expected = size * rank;
        if (projectionMatrix == null || projectionMatrix.size() != expected) {
            throw new IllegalStateException("Projection matrix has "
                    + (projectionMatrix == null ? "no" : String.valueOf(projectionMatrix.size()))
                    + " entries, expected " + expected
                    + " (landmarks=" + size + ", rank=" + rank + ")");
        }
        double[] data = new double[size * size]; // Java zero-fills: columns >= rank stay 0.
        Iterator<Double> values = projectionMatrix.iterator();
        for (int row = 0; row < size; row++) {
            int offset = row * size;
            for (int col = 0; col < rank; col++) {
                data[offset + col] = values.next();
            }
        }
        DenseMatrix64F operator = new DenseMatrix64F(size, size);
        operator.setData(data);
        return operator;
    }

    /** @return the kernel used to compare examples with the landmarks */
    public Kernel getKernel() {
        return this.kernel;
    }

    /** @return the landmark examples (the live list, not a copy) */
    public List<Example> getLandmarks() {
        return this.landmarks;
    }

    /**
     * Projects every example of {@code dataset}, logging progress every
     * {@value #PROJECTION_LOG_INTERVAL} examples.
     *
     * @param dataset examples to project
     * @param str name under which the dense vector is stored in each new example
     * @return a new dataset holding the projected examples, in iteration order
     */
    @Override
    public SimpleDataset getLinearizedDataset(Dataset dataset, String str) {
        SimpleDataset simpleDataset = new SimpleDataset();
        int projected = 0;
        for (Example example : dataset.getExamples()) {
            simpleDataset.addExample(getLinearizedExample(example, str));
            projected++;
            if (projected % PROJECTION_LOG_INTERVAL == 0) {
                logger.info("Projected {} examples.", projected);
            }
        }
        return simpleDataset;
    }

    /**
     * Projects a single example into a new {@link SimpleExample} that carries the original labels
     * and regression values and a single dense representation.
     *
     * @param example example to project
     * @param str name of the dense representation in the returned example
     * @return the projected example
     */
    // Raw map kept: the representation map type expected by SimpleExample is declared outside
    // this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override
    public Example getLinearizedExample(Example example, String str) {
        DenseVector denseVector = new DenseVector(project(example));
        HashMap hashMap = new HashMap();
        hashMap.put(str, denseVector);
        SimpleExample simpleExample = new SimpleExample(example.getLabels(), hashMap);
        simpleExample.setRegressionValues(example.getRegressionValues());
        return simpleExample;
    }

    /**
     * @param example example to project
     * @return the embedding of {@code example} as a {@link DenseVector}
     */
    @Override
    public DenseVector getLinearRepresentation(Example example) {
        return new DenseVector(calculateVector(example));
    }

    /** @return flat row-major projection of {@code landmarks.size() * rank} values */
    public List<Double> getProjectionMatrix() {
        return this.projectionMatrix;
    }

    /** @return the effective rank (number of non-zero output components) */
    public int getRank() {
        return this.rank;
    }

    /**
     * Serializes this instance as indented JSON.
     *
     * @param str destination path, opened through {@code FileUtils.createOutputStream}
     * @throws FileNotFoundException declared for compatibility (presumably raised by
     *     {@code FileUtils} when the file cannot be created — not verified here)
     * @throws IOException if writing fails
     */
    public void save(String str) throws FileNotFoundException, IOException {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
        // try-with-resources: the stream is closed even when Jackson throws.
        try (OutputStreamWriter writer =
                new OutputStreamWriter(FileUtils.createOutputStream(str), StandardCharsets.UTF_8)) {
            objectMapper.writeValue(writer, this);
        }
    }

    /** @param kernel kernel used to compare examples with the landmarks */
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
    }

    /**
     * Replaces the landmarks and drops the cached operator so it is rebuilt on next use.
     *
     * @param list landmark examples
     */
    public void setLandmarks(List<Example> list) {
        this.landmarks = list;
        this.projectionOperator = null;
    }

    /**
     * Replaces the flat projection and drops the cached operator so it is rebuilt on next use.
     *
     * @param list row-major values, {@code landmarks.size() * rank} of them
     */
    public void setProjectionMatrix(List<Double> list) {
        this.projectionMatrix = list;
        this.projectionOperator = null;
    }

    /**
     * Sets the rank and drops the cached operator so it is rebuilt on next use.
     *
     * @param i effective rank matching the stored projection
     */
    public void setRank(int i) {
        this.rank = i;
        this.projectionOperator = null;
    }

    /** @return number of components of each produced vector (the number of landmarks) */
    @Override
    public int getEmbeddingSize() {
        return this.landmarks.size();
    }
}
