Computations.java
/*-
* #%L
* Strange
* %%
* Copyright (C) 2020 Johan Vos
* %%
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the Johan Vos nor the names of its contributors
* may be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
* OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
* OF THE POSSIBILITY OF SUCH DAMAGE.
* #L%
*/
package org.redfx.strange.local;
import org.redfx.strange.*;
import org.redfx.strange.gate.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.function.Consumer;
import static org.redfx.strange.Complex.tensor;
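/**
* Static helper routines for simulating quantum programs locally: building step
* matrices, decomposing steps so that they can be processed without permutations,
* applying gates to a state vector, and a few small number-theory utilities.
*
* @author johan
* @version $Id: $Id
*/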
public class Computations {
// Switch for the memory diagnostics in printMemory(); fixed to false in this source.
private static final boolean debug = false;
// Forwards a diagnostic message to SimpleQuantumExecutionEnvironment.LOG via finer().
// Callers build the message string before calling, regardless of the logger's level.
static void dbg(String s) {
SimpleQuantumExecutionEnvironment.LOG.finer(s);
}
/**
 * Builds the matrix of a single step by tensoring together the matrices of its gates.
 *
 * <p>Qubit positions are visited from {@code nQubits - 1} down to 0. For every position
 * the first gate whose highest affected qubit equals that position is used; if there is
 * none, an {@link Identity} gate is created for that position. Multi-qubit gates make the
 * cursor skip the additional positions they cover. An {@link Oracle} replaces the matrix
 * accumulated so far with its own matrix and ends the walk.
 *
 * @param gates the gates of the step
 * @param nQubits the number of qubits the step spans
 * @param qee the execution environment handed to {@code BlockGate.getMatrix}
 * @return the matrix accumulated for the step
 * @throws RuntimeException if a {@link PermutationGate} is encountered
 */
public static Complex[][] calculateStepMatrix(List<Gate> gates, int nQubits, QuantumExecutionEnvironment qee) {
    Complex[][] a = new Complex[1][1];
    a[0][0] = Complex.ONE;
    int idx = nQubits - 1;
    while (idx >= 0) {
        final int cnt = idx;
        // orElseGet: only allocate the Identity placeholder when no gate claims this position.
        Gate myGate = gates.stream()
                .filter(gate -> gate.getHighestAffectedQubitIndex() == cnt)
                .findFirst()
                .orElseGet(() -> new Identity(cnt));
        dbg("stepmatrix, cnt = " + cnt + ", idx = " + idx + ", myGate = " + myGate);
        // The checks below are deliberately independent (not else-if), as in the original.
        if (myGate instanceof BlockGate block) {
            dbg("calculateStepMatrix for blockgate " + myGate + " of class " + myGate.getClass());
            a = tensor(a, block.getMatrix(qee));
            dbg("calculateStepMatrix for blockgate calculated " + myGate);
            // Skip the extra positions covered by the block.
            idx = idx - block.getSize() + 1;
        }
        if (myGate instanceof SingleQubitGate single) {
            a = tensor(a, single.getMatrix());
        }
        if (myGate instanceof TwoQubitGate two) {
            a = tensor(a, two.getMatrix());
            idx--;
        }
        if (myGate instanceof ThreeQubitGate three) {
            a = tensor(a, three.getMatrix());
            idx = idx - 2;
        }
        if (myGate instanceof PermutationGate) {
            throw new RuntimeException("No perm allowed ");
        }
        if (myGate instanceof Oracle) {
            // The oracle matrix replaces everything accumulated so far; idx = 0 ends the loop.
            a = myGate.getMatrix();
            idx = 0;
        }
        idx--;
    }
    return a;
}
/**
* Decomposes a step into a list of steps that can be processed without permutations
* inside the step itself.
*
* <p>The returned list always contains {@code s}. It is returned as the only element when
* the step is of type PSEUDO (after {@code s.setComplexStep(s.getIndex())}), has no gates,
* has only {@link SingleQubitGate}s, or consists of exactly one {@link Oracle} or
* {@link Swap}. Otherwise every gate is inspected: for two- and three-qubit gates whose
* qubits are not consecutive, descending indices, steps holding a {@link PermutationGate}
* are inserted before and after {@code s}. A {@link ControlledBlockGate} is delegated to
* {@code processBlockGate}. When a {@link ProbabilitiesGate} is found, the step is flagged
* with {@code setInformalStep(true)} and the list built so far is returned immediately.
*
* @param s the step to decompose; its complex-step marker may be modified
* @param nqubit the number of qubits available
* @return the list of steps, containing {@code s} and any surrounding permutation steps
* @throws IllegalArgumentException if a gate refers to a qubit index that is rejected by
*         the range checks below
* @throws RuntimeException if a two-qubit gate uses the same index twice, or a gate is of
*         a type not handled here
*/
public static List<Step> decomposeStep(Step s, int nqubit) {
ArrayList<Step> answer = new ArrayList<>();
answer.add(s);
if (s.getType() == Step.Type.PSEUDO) {
s.setComplexStep(s.getIndex());
return answer;
}
List<Gate> gates = s.getGates();
if (gates.isEmpty()) {
return answer;
}
// A step made of single-qubit gates only never needs a permutation.
boolean simple = gates.stream().allMatch(g -> g instanceof SingleQubitGate);
if (simple) {
return answer;
}
Gate targetGate = gates.get(0);
// if the step has only 1 gate, which is an Oracle or a Swap, return as well
if ((gates.size() == 1) && (targetGate instanceof Oracle || targetGate instanceof Swap)) {
return answer;
}
// at least one non-singlequbitgate
// NOTE(review): firstGates is only ever added to in this method; it is never read.
List<Gate> firstGates = new ArrayList<>();
for (Gate gate : gates) {
// NOTE(review): this check uses '>' whereas validateGates and the two-qubit check below
// use '>=' against the qubit count; an index equal to nqubit passes here — confirm intent.
if (gate.getHighestAffectedQubitIndex() > nqubit) {
throw new IllegalArgumentException("Only "+nqubit+" qubits available while Gate "+gate+" requires qubit "+gate.getHighestAffectedQubitIndex());
}
if (gate instanceof ProbabilitiesGate) {
// Early exit: permutation steps added for earlier gates of this step stay in the list.
s.setInformalStep(true);
return answer;
}
if (gate instanceof BlockGate) {
if (gate instanceof ControlledBlockGate) {
processBlockGate((ControlledBlockGate) gate, answer);
}
firstGates.add(gate);
} else if (gate instanceof SingleQubitGate) {
firstGates.add(gate);
} else if (gate instanceof TwoQubitGate) {
TwoQubitGate tqg = (TwoQubitGate) gate;
int first = tqg.getMainQubitIndex();
int second = tqg.getSecondQubitIndex();
if ((first >= nqubit) || (second >= nqubit)) {
throw new IllegalArgumentException("Step " + s + " uses a gate with invalid index " + first + " or " + second);
}
// Only the layout "main qubit directly above second qubit" needs no permutation.
if (first == second + 1) {
firstGates.add(gate);
} else {
if (first == second) {
throw new RuntimeException("Wrong gate, first == second for " + gate);
}
if (first > second) {
// Main qubit is above the second one but not adjacent: surround s with the same
// PermutationGate(first - 1, second, nqubit) before and after.
// presumably this gate exchanges the two given qubits — confirm in PermutationGate
PermutationGate pg = new PermutationGate(first - 1, second, nqubit);
Step prePermutation = new Step(pg);
Step postPermutation = new Step(pg);
answer.add(0, prePermutation);
answer.add(postPermutation);
// The complex-step index moves from s to the trailing permutation step; s gets -1.
// presumably -1 marks an intermediate step — TODO confirm against Step
postPermutation.setComplexStep(s.getIndex());
s.setComplexStep(-1);
} else {
// Main qubit is below the second one: first surround s with
// PermutationGate(first, second, nqubit) ...
PermutationGate pg = new PermutationGate(first, second, nqubit);
Step prePermutation = new Step(pg);
Step prePermutationInv = new Step(pg);
int realStep = s.getIndex();
s.setComplexStep(-1);
answer.add(0, prePermutation);
answer.add(prePermutationInv);
Step postPermutation = new Step();
Step postPermutationInv = new Step();
// ... and, unless the two qubits were adjacent, add a second pair built from
// PermutationGate(second - 1, first, nqubit) at positions 1 and 3 of the list.
if (first != second - 1) {
PermutationGate pg2 = new PermutationGate(second - 1, first, nqubit);
postPermutation.addGate(pg2);
postPermutationInv.addGate(pg2);
answer.add(1, postPermutation);
answer.add(3, postPermutationInv);
}
prePermutationInv.setComplexStep(realStep);
}
}
} else if (gate instanceof ThreeQubitGate) {
ThreeQubitGate tqg = (ThreeQubitGate) gate;
int first = tqg.getMainQubit();
int second = tqg.getSecondQubit();
int third = tqg.getThirdQubit();
// sFirst/sSecond/sThird track where each qubit ends up after the permutations added so far.
int sFirst = first;
int sSecond = second;
int sThird = third;
if ((first == second + 1) && (second == third + 1)) {
firstGates.add(gate);
} else {
// p0idx counts the permutation pairs added so far: each new "pre" step is inserted at
// index p0idx and its matching "post" step at answer.size() - p0idx.
int p0idx = 0;
int maxs = Math.max(second, third);
// Stage 1: if the main qubit is not the highest of the three, exchange it with the highest.
if (first < maxs) {
PermutationGate pg = new PermutationGate(first, maxs, nqubit);
Step prePermutation = new Step(pg);
Step postPermutation = new Step(pg);
answer.add(p0idx, prePermutation);
answer.add(answer.size() - p0idx, postPermutation);
p0idx++;
postPermutation.setComplexStep(s.getIndex());
s.setComplexStep(-1);
sFirst = maxs;
if (second > third) {
sSecond = first;
} else {
sThird = first;
}
}
// Stage 2: bring the second qubit to the position directly below the main qubit.
if (sSecond != sFirst - 1) {
PermutationGate pg = new PermutationGate(sFirst - 1, sSecond, nqubit);
Step prePermutation = new Step(pg);
Step postPermutation = new Step(pg);
answer.add(p0idx, prePermutation);
answer.add(answer.size() - p0idx, postPermutation);
p0idx++;
postPermutation.setComplexStep(s.getIndex());
s.setComplexStep(-1);
sSecond = sFirst - 1;
}
// Stage 3: bring the third qubit to the position two below the main qubit.
// NOTE(review): sThird is not updated when stage 2 moves a qubit out of position
// sFirst - 1 — verify the case where the third qubit sat there.
if (sThird != sFirst - 2) {
PermutationGate pg = new PermutationGate(sFirst - 2, sThird, nqubit);
Step prePermutation = new Step(pg);
Step postPermutation = new Step(pg);
answer.add(p0idx, prePermutation);
answer.add(answer.size() - p0idx, postPermutation);
p0idx++;
postPermutation.setComplexStep(s.getIndex());
s.setComplexStep(-1);
sThird = sFirst - 2;
}
}
} else {
// NOTE(review): the message is narrower than the code — BlockGate and ThreeQubitGate
// are accepted above as well.
throw new RuntimeException("Gate must be SingleQubit or TwoQubit");
}
}
return answer;
}
/**
 * Writes a matrix to standard output, one line per row.
 *
 * <p>Each line has the form {@code m[<row>]: } followed by the row's entries, every entry
 * being followed by a single space.
 *
 * @param a the matrix to print
 */
public static void printMatrix(Complex[][] a) {
    int rowIndex = 0;
    for (Complex[] row : a) {
        StringBuilder entries = new StringBuilder();
        for (Complex entry : row) {
            entries.append(entry);
            entries.append(" ");
        }
        System.out.println("m[" + rowIndex + "]: " + entries);
        rowIndex++;
    }
}
/**
 * Computes the multiplicative inverse of {@code a} modulo {@code b} with the extended
 * Euclidean algorithm.
 *
 * <p>Only the Bezout coefficient of {@code a} is tracked. When the remainder sequence
 * reaches 1, that coefficient {@code s} satisfies {@code s * a == 1 (mod b)}; a
 * non-positive coefficient is shifted by {@code b} before it is returned.
 *
 * @param a the value to invert
 * @param b the modulus
 * @return a value {@code x} with {@code (a * x) % b == 1}
 * @throws ArithmeticException if the remainder sequence reaches 0 before 1, i.e. no
 *         inverse exists (for positive inputs: {@code a} and {@code b} are not coprime)
 */
public static int getInverseModulus(int a, int b) {
    int r0 = a;
    int r1 = b;
    int s0 = 1;
    int s1 = 0;
    while (r1 != 1) {
        if (r1 == 0) {
            // Previously surfaced as an opaque "/ by zero"; same exception type, clearer message.
            throw new ArithmeticException(a + " has no inverse modulo " + b);
        }
        int q = r0 / r1;
        int r2 = r0 % r1;
        int s2 = s0 - q * s1;
        r0 = r1;
        r1 = r2;
        s0 = s1;
        s1 = s2;
    }
    return s1 > 0 ? s1 : s1 + b;
}
/**
 * Returns the greatest common divisor of two integers using Euclid's algorithm.
 *
 * <p>The larger argument is used as the initial dividend and the smaller one as the
 * initial divisor. {@code gcd(x, 0)} is {@code x} for non-negative {@code x}.
 *
 * @param a the first value
 * @param b the second value
 * @return the last non-zero remainder of the Euclidean remainder sequence
 */
public static int gcd(int a, int b) {
    int dividend = Math.max(a, b);
    int divisor = Math.min(a, b);
    while (divisor != 0) {
        int remainder = dividend % divisor;
        dividend = divisor;
        divisor = remainder;
    }
    return dividend;
}
/**
 * Estimates a denominator for the ratio {@code p / 2^n}, where {@code n} is the smallest
 * exponent with {@code 2^n >= max}.
 *
 * <p>The ratio is nudged upwards by {@code 1e-6} and then handed to
 * {@link #fraction(double, int)}.
 *
 * @param p the numerator
 * @param max the bound used both to size the power-of-two denominator and to stop the
 *            continued-fraction expansion
 * @return the denominator found by {@link #fraction(double, int)}
 */
public static int fraction(int p, int max) {
    // ceil(log2(max)) in integer arithmetic: a floating-point log ratio can land just above
    // an integer for exact powers of two, which ceil() would then round one too high.
    int length = max > 1 ? Integer.SIZE - Integer.numberOfLeadingZeros(max - 1) : 0;
    int dim = 1 << length;
    double r = (double) p / dim + .000001;
    return fraction(r, max);
}
/**
 * Expands {@code d} as a continued fraction and returns a convergent denominator.
 *
 * <p>Convergent denominators are generated with the recurrence
 * {@code k = a * k_1 + k_2}. The expansion stops as soon as the latest denominator is no
 * longer below {@code max} or the fractional remainder drops to {@code 1e-15} or less.
 * The value returned is the denominator that preceded the last one generated; if the
 * loop never runs (no fractional part), that is 1.
 *
 * @param d the value to expand
 * @param max the bound on the denominators
 * @return the second-to-last denominator generated
 */
public static int fraction(double d, int max) {
    final double eps = 1e-15;
    int a = (int) d;
    double r = d - a;
    int k = -1;
    // Only the denominators influence the result, so numerators are not tracked.
    int k_2 = 1;
    int k_1 = 0;
    while ((k < max) && (r > eps)) {
        k = a * k_1 + k_2;
        k_2 = k_1;
        k_1 = k;
        double rec = 1 / r;
        a = (int) rec;
        r = rec - a;
    }
    return k_2;
}
/**
 * Creates a square identity matrix.
 *
 * @param dim the number of rows and columns
 * @return a {@code dim x dim} matrix holding {@code Complex.ONE} on the diagonal and
 *         {@code Complex.ZERO} everywhere else
 */
public static Complex[][] createIdentity(int dim) {
    Complex[][] identity = new Complex[dim][dim];
    for (int row = 0; row < dim; row++) {
        Arrays.fill(identity[row], Complex.ZERO);
        identity[row][row] = Complex.ONE;
    }
    return identity;
}
/**
 * Prints the JVM's free, maximum, total and used memory (in KiB) to standard error.
 * Does nothing unless the {@code debug} flag of this class is set.
 */
public static void printMemory() {
    if (debug) {
        Runtime runtime = Runtime.getRuntime();
        long freeKb = runtime.freeMemory() / 1024;
        long maxKb = runtime.maxMemory() / 1024;
        long totalKb = runtime.totalMemory() / 1024;
        long usedKb = totalKb - freeKb;
        System.err.println("free = " + freeKb + ", mm = " + maxKb + ", tm = " + totalKb + ", used = " + usedKb);
    }
}
/**
 * Returns a copy of a vector in which the roles of bit {@code a} and bit {@code b} of the
 * element index are exchanged.
 *
 * <p>Element {@code i} of the result is taken from the source index obtained by flipping
 * bits {@code a} and {@code b} of {@code i} whenever those two bits differ; otherwise it
 * is taken from index {@code i} itself.
 *
 * @param vector the source vector
 * @param a the first bit position
 * @param b the second bit position
 * @return a new vector of the same length
 * @throws IllegalArgumentException if {@code 1 << a} or {@code 1 << b} is not below the
 *         vector length
 */
public static Complex[] permutateVector(Complex[] vector, int a, int b) {
    final int amask = 1 << a;
    final int bmask = 1 << b;
    if ((amask >= vector.length) || (bmask >= vector.length)) {
        throw new IllegalArgumentException("Can not permutate element "+a+" and "+b+" of vector sized "+vector.length);
    }
    final int dim = vector.length;
    final Complex[] permuted = new Complex[dim];
    for (int target = 0; target < dim; target++) {
        boolean aSet = (target & amask) != 0;
        boolean bSet = (target & bmask) != 0;
        // Flip both bits only when they differ; equal bits map the index onto itself.
        int source = (aSet != bSet) ? (target ^ amask) ^ bmask : target;
        permuted[target] = vector[source];
    }
    return permuted;
}
static int nested = 0; // allows us to e.g. show only 2 nested steps
/**
 * Applies the gates of one step to a state vector.
 *
 * <p>If the step consists of immediate measurements (and identities) only, the
 * measurement is carried out directly. Otherwise the gate list is completed to cover all
 * qubits and applied to the vector. The {@code nested} counter is incremented for the
 * duration of that computation.
 *
 * @param gates the gates of the step
 * @param vector the current state vector
 * @param length the number of qubits
 * @return the state vector after the step
 */
public static Complex[] calculateNewState(List<Gate> gates, Complex[] vector, int length) {
    if (containsImmediateMeasurementOnly(gates)) {
        return doImmediateMeasurement(gates, vector, length);
    }
    nested++;
    try {
        return getNextProbability(getAllGates(gates, length), vector);
    } finally {
        // Restore the counter even when gate validation or the computation throws.
        nested--;
    }
}
/**
 * Applies an ordered gate list to a vector, short-cutting the case of a lone swap.
 *
 * <p>When the first gate is a {@link Swap} and every remaining gate is an
 * {@link Identity}, the swap is applied by re-indexing the vector. All other cases are
 * delegated to {@code getNextProbability2}.
 *
 * @param gates the gates to apply; the first entry is the one examined for the shortcut
 * @param v the vector to transform
 * @return the transformed vector
 */
private static Complex[] getNextProbability(List<Gate> gates, Complex[] v) {
    // subList's upper bound is exclusive: it must be size(), not size() - 1, otherwise the
    // last gate escapes the check and a non-identity gate after a Swap is silently dropped.
    boolean onlyIdentity = gates.size() == 1
            || gates.subList(1, gates.size()).stream().allMatch(g -> g instanceof Identity);
    if (onlyIdentity && gates.get(0) instanceof Swap swap) {
        return processSwapGate(swap, v);
    }
    return getNextProbability2(gates, v);
}
/**
 * Applies a swap gate to a vector by re-indexing: element {@code i} of the result is the
 * element of {@code v} whose index is {@code i} with the swap's two bit positions exchanged.
 *
 * @param swap the swap gate supplying the two qubit indices
 * @param v the vector to transform
 * @return a new vector of the same length
 */
static Complex[] processSwapGate(Swap swap, Complex[] v) {
    Complex[] swapped = new Complex[v.length];
    final int firstBit = swap.getMainQubitIndex();
    final int secondBit = swap.getSecondQubitIndex();
    Arrays.setAll(swapped, index -> v[swapBits(index, firstBit, secondBit)]);
    return swapped;
}
/**
 * Applies an ordered gate list to a vector.
 *
 * <p>The first gate acts on the leading {@code gate.getSize()} index bits of the vector,
 * the remaining gates on the rest. Three cases are distinguished:
 * <ul>
 * <li>a single gate: it must span the whole vector and is applied directly;</li>
 * <li>all remaining gates are identities: the first gate is applied to every strided
 *     slice of the vector;</li>
 * <li>otherwise: the remaining gates are applied recursively to each contiguous part,
 *     after which the first gate's matrix combines the parts.</li>
 * </ul>
 * A gate reporting {@code hasOptimization()} is applied through {@code applyOptimize}
 * instead of its matrix in the first two cases.
 *
 * @param gates the gates to apply, never empty
 * @param v the vector to transform
 * @return the transformed vector
 * @throws IllegalArgumentException if a single gate does not match the vector size
 */
private static Complex[] getNextProbability2(List<Gate> gates, Complex[] v) {
    Gate gate = gates.get(0);
    int gatedim = 1 << gate.getSize();
    int size = v.length;
    dbg("GETNEXTPROBABILITY asked for size = " + size + " and gates = " + gates+", gatedim = "+gatedim);

    if (gates.size() == 1) {
        if (gatedim != size) {
            System.err.println("problem with matrix for gate " + gate);
            throw new IllegalArgumentException("wrong matrix size " + gatedim + " vs vector size " + v.length);
        }
        return gate.hasOptimization() ? gate.applyOptimize(v) : applyMatrix(gate.getMatrix(), v);
    }

    int partdim = size / gatedim;
    Complex[] answer = new Complex[size];
    List<Gate> nextGates = gates.subList(1, gates.size());

    if (nextGates.stream().allMatch(g -> g instanceof Identity)) {
        dbg("ONLY IDENTITY!! partdim = "+partdim);
        boolean optimized = gate.hasOptimization();
        // Fetched at most once, and only if needed: the matrix is the same for every slice.
        Complex[][] matrix = null;
        for (int j = 0; j < partdim; j++) {
            dbg("do part "+j+" from "+partdim);
            // Gather the slice of entries that differ only in the bits the gate acts on.
            Complex[] oldv = new Complex[gatedim];
            for (int i = 0; i < gatedim; i++) {
                oldv[i] = v[i * partdim + j];
            }
            Complex[] newv;
            if (optimized) {
                dbg("OPTPART!");
                newv = gate.applyOptimize(oldv);
            } else {
                if (matrix == null) {
                    dbg("GET MATRIX for "+gate);
                    matrix = gate.getMatrix();
                }
                newv = applyMatrix(matrix, oldv);
            }
            for (int i = 0; i < gatedim; i++) {
                answer[i * partdim + j] = newv[i];
            }
            dbg("done part");
        }
        return answer;
    }

    // General case: transform each contiguous part with the remaining gates first ...
    Complex[][] vsub = new Complex[gatedim][];
    for (int i = 0; i < gatedim; i++) {
        Complex[] vorig = Arrays.copyOfRange(v, i * partdim, (i + 1) * partdim);
        vsub[i] = getNextProbability(nextGates, vorig);
    }
    // ... then mix the parts with the first gate's matrix.
    Complex[][] matrix = gate.getMatrix();
    for (int i = 0; i < gatedim; i++) {
        for (int j = 0; j < partdim; j++) {
            Complex sum = Complex.ZERO;
            for (int k = 0; k < gatedim; k++) {
                sum = sum.add(matrix[i][k].mul(vsub[k][j]));
            }
            answer[j + i * partdim] = sum;
        }
    }
    return answer;
}

/** Returns the product of a square matrix and a vector of matching dimension. */
private static Complex[] applyMatrix(Complex[][] matrix, Complex[] vec) {
    int dim = vec.length;
    Complex[] product = new Complex[dim];
    for (int i = 0; i < dim; i++) {
        Complex sum = Complex.ZERO;
        for (int k = 0; k < dim; k++) {
            sum = sum.add(matrix[i][k].mul(vec[k]));
        }
        product[i] = sum;
    }
    return product;
}
/**
 * Verifies that every gate operates on qubits that are part of this Program.
 * For example, applying a 3-sized gate to qubit 2 in a 3-qubit circuit is rejected.
 *
 * @param gates the gates to check
 * @param nQubits the number of qubits available
 * @throws IllegalArgumentException for the first gate whose highest affected qubit index
 *         is not below {@code nQubits}
 */
private static void validateGates(List<Gate> gates, int nQubits) {
    gates.stream()
            .filter(candidate -> candidate.getHighestAffectedQubitIndex() >= nQubits)
            .findFirst()
            .ifPresent(offender -> {
                throw new IllegalArgumentException
                        ("Gate "+offender+" operates on qubit "+ offender.getHighestAffectedQubitIndex()+" but we have only "+nQubits+" qubits.");
            });
}
/**
 * Completes a step's gate list so that every qubit position is covered.
 *
 * <p>After validating the gates, qubit positions are visited from {@code nQubits - 1}
 * down to 0. For each position the first gate whose highest affected qubit equals that
 * position is taken, or a new {@link Identity} if there is none. Multi-qubit gates make
 * the cursor skip the additional positions they cover; an {@link Oracle} ends the walk.
 *
 * @param gates the gates of the step
 * @param nQubits the number of qubits
 * @return the gates in descending qubit order, padded with identities
 * @throws IllegalArgumentException if a gate addresses a qubit outside the register
 * @throws RuntimeException if a {@link PermutationGate} is encountered
 */
private static List<Gate> getAllGates(List<Gate> gates, int nQubits) {
    validateGates(gates, nQubits);
    dbg("getAllGates, orig = " + gates);
    List<Gate> answer = new ArrayList<>();
    int idx = nQubits - 1;
    while (idx >= 0) {
        final int cnt = idx;
        // orElseGet: only allocate the Identity placeholder when no gate claims this position.
        Gate myGate = gates.stream()
                .filter(gate -> gate.getHighestAffectedQubitIndex() == cnt)
                .findFirst()
                .orElseGet(() -> new Identity(cnt));
        dbg("getAllGates, cnt = " + cnt + ", idx = " + idx + ", myGate = " + myGate);
        answer.add(myGate);
        // The checks below are deliberately independent (not else-if), as in the original.
        if (myGate instanceof BlockGate block) {
            idx = idx - block.getSize() + 1;
            dbg("processed blockgate, size = " + block.getSize() + ", idx = " + idx);
        }
        if (myGate instanceof TwoQubitGate) {
            idx--;
        }
        if (myGate instanceof ThreeQubitGate) {
            idx = idx - 2;
        }
        if (myGate instanceof PermutationGate) {
            throw new RuntimeException("No perm allowed ");
        }
        if (myGate instanceof Oracle) {
            // Setting idx to 0 makes the decrement below terminate the loop.
            idx = 0;
        }
        idx--;
    }
    return answer;
}
/**
* Surrounds the last step in {@code answer} with permutation steps for a controlled block
* gate whose control qubit is not located where the block needs it.
*
* <p>A list of {@link PermutationGate}s is derived from the positions of the control
* qubit and the block. For each of them, one step is appended to the end of
* {@code answer} and one is inserted at the front. The last appended step takes over the
* complex-step value of the original last step, which is then set to -1.
*
* @param gate the controlled block gate being processed
* @param answer the step list to extend; its last element is treated as the master step
* @throws IllegalArgumentException if the control qubit lies above the block's main qubit
*         but closer than the block's size
*/
private static void processBlockGate(ControlledBlockGate gate, ArrayList<Step> answer) {
Step master = answer.get(answer.size() -1);
gate.calculateHighLow();
// The value read here is overwritten in both branches below before it is used.
int low = gate.getLow();
int control = gate.getControlQubit();
int idx = gate.getMainQubitIndex();
int high = control;
int size = gate.getSize();
int gap = control - idx;
List<PermutationGate> perm = new LinkedList<>();
Block block = gate.getBlock();
int bs = block.getNQubits();
if (control > idx) {
// Control above the block: it may not fall inside the bs qubits starting at idx.
if (gap < bs) {
throw new IllegalArgumentException("Can't have control at " + control + " for gate with size " + bs + " starting at " + idx);
}
low = idx;
// gap == bs needs no permutation; a larger gap gets a single permutation between the
// control position and position idx + bs (control - gap + bs).
if (gap > bs) {
high = control;
size = high - low + 1;
PermutationGate pg = new PermutationGate(control, control - gap + bs, low + size);
perm.add(pg);
}
} else {
// Control at or below the block's main qubit: build a chain of neighbour permutations
// (i, i + 1) over the whole span; each is inserted at the head, so the list ends up in
// descending order of i.
low = control;
high = idx + bs - 1;
size = high - low + 1;
// gate.correctHigh(low+bs);
for (int i = low; i < low + size - 1; i++) {
PermutationGate pg = new PermutationGate(i, i + 1, low + size);
perm.add(0, pg);
}
}
for (int i = 0; i < perm.size(); i++) {
PermutationGate pg = perm.get(i);
Step lpg = new Step(pg);
if (i < perm.size()-1) {
lpg.setComplexStep(-1);
} else {
// the complex step should be the last part of the step
lpg.setComplexStep(master.getComplexStep());
master.setComplexStep(-1);
}
// One step with this permutation goes to the end, a second one to the very front.
answer.add(lpg);
answer.add(0, new Step(pg));
}
}
// TODO: make this a utility method
/**
 * Derives, for every qubit, the accumulated squared magnitude of all vector entries whose
 * index has that qubit's bit set.
 *
 * <p>The number of qubits is {@code log2} of the vector length, rounded to the nearest
 * integer. Entry {@code i} of the result belongs to bit {@code i} of the index.
 *
 * @param vectorresult the state vector
 * @return one accumulated value per qubit
 */
public static double[] calculateQubitStatesFromVector(Complex[] vectorresult) {
    final int qubitCount = (int) Math.round(Math.log(vectorresult.length) / Math.log(2));
    final double[] perQubit = new double[qubitCount];
    final int stateCount = 1 << qubitCount;
    for (int qubit = 0; qubit < qubitCount; qubit++) {
        for (int state = 0; state < stateCount; state++) {
            boolean bitSet = ((state >> qubit) & 1) == 1;
            if (bitSet) {
                perQubit[qubit] += vectorresult[state].abssqr();
            }
        }
    }
    return perQubit;
}
/**
 * Evaluates the first ImmediateMeasurement gate in the supplied list.
 * A measurement is done based on the existing probabilities, and the resulting vector is
 * post-conditioned on the measured value. If a callback was provided when the
 * ImmediateMeasurement gate was created, it is called with the measurement result.
 *
 * <p>An outcome whose accumulated probability is exactly zero is never selected when the
 * other outcome has non-zero probability, so the renormalisation does not divide by zero.
 *
 * @param gates the gates of the step; must contain an ImmediateMeasurement
 * @param vector the state vector before the measurement
 * @param length the number of qubits (not used by this method)
 * @return the renormalised state vector after the measurement
 * @throws IllegalArgumentException if {@code gates} holds no ImmediateMeasurement
 */
static Complex[] doImmediateMeasurement(List<Gate> gates, Complex[] vector, int length) {
    int size = vector.length;
    ImmediateMeasurement mGate = gates.stream()
            .filter(g -> g instanceof ImmediateMeasurement)
            .map(g -> (ImmediateMeasurement) g)
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException("Need an ImmediateMeasurement gate"));
    int idx = mGate.getMainQubitIndex();
    // p[0] / p[1]: total squared magnitude of the entries whose idx-th index bit is 0 / 1.
    double[] p = new double[2];
    for (int i = 0; i < size; i++) {
        p[(i >> idx) & 1] += vector[i].abssqr();
    }
    double rnd = Math.random();
    int pick;
    if (p[1] == 0) {
        pick = 0;
    } else if (p[0] == 0) {
        pick = 1;
    } else {
        pick = rnd > p[0] ? 1 : 0;
    }
    // Loop-invariant renormalisation factor.
    double scale = 1 / Math.sqrt(p[pick]);
    Complex[] answer = new Complex[size];
    for (int i = 0; i < size; i++) {
        answer[i] = (((i >> idx) & 1) == pick) ? vector[i].mul(scale) : Complex.ZERO;
    }
    Consumer<Boolean> consumer = mGate.getConsumer();
    if (consumer != null) {
        consumer.accept(pick != 0);
    }
    return answer;
}
/**
 * Tells whether a gate list holds at least one ImmediateMeasurement gate and nothing
 * else apart from Identity gates.
 *
 * @param gates the gates to inspect
 * @return {@code true} if there is an ImmediateMeasurement and every other gate is an
 *         Identity; {@code false} otherwise, including for an empty list
 */
static boolean containsImmediateMeasurementOnly(List<Gate> gates) {
    boolean measurementSeen = false;
    for (Gate candidate : gates) {
        if (candidate instanceof ImmediateMeasurement) {
            measurementSeen = true;
        } else if (!(candidate instanceof Identity)) {
            return false;
        }
    }
    return measurementSeen;
}
/**
 * Exchanges the bits at positions {@code b0} and {@code b1} of {@code val}.
 *
 * @param val an integer value, containing 32 bits
 * @param b0 position of the first bit
 * @param b1 position of the second bit
 * @return {@code val} with the two bits exchanged; unchanged when both bits are equal
 */
public static int swapBits(int val, int b0, int b1) {
    // The bits differ exactly when the XOR of the two shifted values has its lowest bit set.
    boolean differ = (((val >> b0) ^ (val >> b1)) & 1) != 0;
    if (!differ) {
        return val;
    }
    // Flipping both differing bits is the same as swapping them.
    return val ^ ((1 << b0) | (1 << b1));
}
}