# Source code for pyquil.wavefunction

##############################################################################
# Copyright 2018 Rigetti Computing
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
##############################################################################
"""
Module containing the Wavefunction object and methods for working with wavefunctions.
"""
import itertools
from typing import Dict, Iterator, List, Optional, Sequence, cast

import numpy as np

# Size, in octets (bytes), of a single double-precision float.
OCTETS_PER_DOUBLE_FLOAT = 8
# A complex double is a pair of doubles (real and imaginary parts), hence twice the size.
OCTETS_PER_COMPLEX_DOUBLE = OCTETS_PER_DOUBLE_FLOAT * 2


class Wavefunction(object):
    """
    Encapsulate a wavefunction representing a quantum state
    as returned by :py:class:`~pyquil.api.WavefunctionSimulator`.

    .. note::

        The elements of the wavefunction are ordered by bitstring. E.g., for two qubits the order
        is ``00, 01, 10, 11``, where the bits **are ordered in reverse** by the qubit index,
        i.e., for qubits 0 and 1 the bitstring ``01`` indicates that qubit 0 is in the state 1.
        See also :ref:`the related documentation section in the WavefunctionSimulator Overview
        <basis_ordering>`.
    """

    def __init__(self, amplitude_vector: np.ndarray):
        """
        Initializes a wavefunction

        :param amplitude_vector: A numpy array of complex amplitudes. Its length must be a power
            of two and the squared magnitudes must sum to 1 (within ``np.isclose`` tolerance).
        :raises TypeError: If the length of ``amplitude_vector`` is zero or not a power of two.
        :raises ValueError: If the probabilities do not sum to approximately 1.
        """
        length = len(amplitude_vector)
        # ``n & (n - 1)`` clears the lowest set bit; it is zero exactly when n is a power of two.
        if length == 0 or length & (length - 1) != 0:
            raise TypeError("Amplitude vector must have a length that is a power of two")
        self.amplitudes = np.asarray(amplitude_vector)
        sumprob = np.sum(self.probabilities())
        if not np.isclose(sumprob, 1.0):
            raise ValueError(
                "The wavefunction is not normalized. "
                "The probabilities sum to {} instead of 1".format(sumprob)
            )

    @staticmethod
    def zeros(qubit_num: int) -> "Wavefunction":
        """
        Constructs the groundstate wavefunction for a given number of qubits.

        :param qubit_num: The number of qubits.
        :return: A Wavefunction whose only non-zero amplitude is the all-zeros bitstring.
        """
        # Use a complex dtype so complex amplitudes can later be assigned through
        # ``__setitem__``; a float64 array rejects complex values.
        amplitude_vector = np.zeros(2 ** qubit_num, dtype=np.complex128)
        amplitude_vector[0] = 1.0
        return Wavefunction(amplitude_vector)

    @staticmethod
    def from_bit_packed_string(coef_string: bytes) -> "Wavefunction":
        """
        From a bit packed string, unpacks to get the wavefunction

        :param coef_string: Concatenated big-endian complex doubles
            (``OCTETS_PER_COMPLEX_DOUBLE`` octets each). Trailing octets that do not form a whole
            complex double are ignored.
        :return: The decoded Wavefunction.
        """
        num_cfloat = len(coef_string) // OCTETS_PER_COMPLEX_DOUBLE
        big_endian = np.ndarray(shape=(num_cfloat,), buffer=coef_string, dtype=">c16")
        # ``astype`` copies: a view over the immutable ``bytes`` buffer is read-only, which
        # would make ``__setitem__`` fail on the resulting wavefunction.
        amplitude_vector = big_endian.astype(np.complex128)
        return Wavefunction(amplitude_vector)

    def __len__(self) -> int:
        """Return the number of qubits, i.e. log2 of the number of amplitudes."""
        return len(self.amplitudes).bit_length() - 1

    def __iter__(self) -> Iterator[complex]:
        """Iterate over the amplitudes in bitstring order."""
        return cast(Iterator[complex], self.amplitudes.__iter__())

    def __getitem__(self, index: int) -> complex:
        """Return the amplitude at ``index``."""
        return cast(complex, self.amplitudes[index])

    def __setitem__(self, key: int, value: complex) -> None:
        """Overwrite the amplitude at ``key`` (normalization is not re-checked)."""
        self.amplitudes[key] = value

    def __str__(self) -> str:
        return self.pretty_print(decimal_digits=10)

    def probabilities(self) -> np.ndarray:
        """Returns an array of probabilities in lexicographical order"""
        return np.abs(self.amplitudes) ** 2  # type: ignore

    def get_outcome_probs(self) -> Dict[str, float]:
        """
        Parses a wavefunction (array of complex amplitudes) and returns a dictionary of
        outcomes and associated probabilities.

        :return: A dict with outcomes as keys and probabilities as values.
        :rtype: dict
        """
        qubit_num = len(self)
        return {
            get_bitstring_from_index(index, qubit_num): prob
            for index, prob in enumerate(self.probabilities())
        }

    def pretty_print_probabilities(self, decimal_digits: int = 2) -> Dict[str, float]:
        """
        Returns outcome probabilities rounded to ``decimal_digits``, omitting every outcome whose
        rounded probability is zero. Despite its name, this method does not print anything.

        :param int decimal_digits: The number of digits to round to.
        :return: A dict with outcomes as keys and probabilities as values.
        """
        qubit_num = len(self)
        outcome_dict = {}
        for index, prob in enumerate(self.probabilities()):
            rounded = round(prob, decimal_digits)
            if rounded != 0.0:
                outcome_dict[get_bitstring_from_index(index, qubit_num)] = rounded
        return outcome_dict

    def pretty_print(self, decimal_digits: int = 2) -> str:
        """
        Returns a string repr of the wavefunction, ignoring all outcomes with approximately zero
        amplitude (up to a certain number of decimal digits) and rounding the amplitudes to
        decimal_digits.

        :param int decimal_digits: The number of digits to round to.
        :return: A string of ``amplitude|bitstring>`` terms joined by `` + `` (empty if every
            rounded amplitude is zero).
        """
        qubit_num = len(self)
        terms = []
        for index, amplitude in enumerate(self.amplitudes):
            # Round the real and imaginary parts independently.
            rounded = round(amplitude.real, decimal_digits) + round(amplitude.imag, decimal_digits) * 1.0j
            if rounded != 0.0:
                outcome = get_bitstring_from_index(index, qubit_num)
                terms.append(str(rounded) + "|{}>".format(outcome))
        return " + ".join(terms)

    def plot(self, qubit_subset: Optional[Sequence[int]] = None) -> None:
        """
        Plots a bar chart with bitstring on the x axis and probability on the y axis.

        Requires matplotlib, which is imported lazily.

        :param qubit_subset: Optional sequence of basis-state indices used to plot only a subset
            of the Hilbert space. When omitted or empty, every outcome is plotted.
        :raises IndexError: If an index in ``qubit_subset`` is negative or out of range.
        """
        import matplotlib.pyplot as plt

        prob_dict = self.get_outcome_probs()
        if qubit_subset:
            qubit_num = len(self)
            sub_dict = {}
            for i in qubit_subset:
                if i < 0:
                    raise IndexError("Index {} must be non-negative.".format(i))
                if i > (2 ** qubit_num - 1):
                    raise IndexError("Index {} too large for {} qubits.".format(i, qubit_num))
                outcome = get_bitstring_from_index(i, qubit_num)
                sub_dict[outcome] = prob_dict[outcome]
            prob_dict = sub_dict
        # Materialize the dict views: matplotlib expects array-like sequences, not views.
        plt.bar(range(len(prob_dict)), list(prob_dict.values()), align="center", color="#6CAFB7")
        plt.xticks(range(len(prob_dict)), list(prob_dict.keys()))
        plt.show()

    def sample_bitstrings(self, n_samples: int) -> np.ndarray:
        """
        Sample bitstrings from the distribution defined by the wavefunction.

        :param n_samples: The number of bitstrings to sample
        :return: An integer array of shape (n_samples, n_qubits); each row holds the bits of the
            sampled outcome in the same left-to-right order as its bitstring label.
        """
        qubit_num = len(self)
        probs = self.probabilities()
        # ``__init__`` accepts any state normalized to ``np.isclose`` tolerance, which is looser
        # than the check in ``np.random.choice``; renormalize so valid states never fail here.
        probs = probs / np.sum(probs)
        inds = np.random.choice(2 ** qubit_num, n_samples, p=probs)
        # Decode each sampled index into its bits, most significant (leftmost) first, instead of
        # materializing all 2**n bitstrings up front.
        shifts = np.arange(qubit_num - 1, -1, -1)
        bitstrings: np.ndarray = (inds[:, np.newaxis] >> shifts) & 1
        return bitstrings
def get_bitstring_from_index(index: int, qubit_num: int) -> str:
    """
    Returns the bitstring in lexical order that corresponds to the given index in 0 to 2^(qubit_num)

    :param int index: The basis-state index, in ``range(2 ** qubit_num)``.
    :param int qubit_num: The number of qubits, i.e. the length of the returned bitstring.
    :return: the bitstring, zero-padded on the left to ``qubit_num`` characters
    :rtype: str
    :raises IndexError: If ``index`` is negative or too large for ``qubit_num`` qubits.
    """
    if index < 0:
        # bin() of a negative number starts with "-0b", which would yield a bogus bitstring.
        raise IndexError("Index {} must be non-negative.".format(index))
    if index > (2 ** qubit_num - 1):
        raise IndexError("Index {} too large for {} qubits.".format(index, qubit_num))
    return bin(index)[2:].rjust(qubit_num, "0")
def _octet_bits(o: int) -> List[int]:
    """
    Split an octet into its eight bits.

    :param o: The octet, an int in the range 0..255 inclusive.
    :return: The bits as a list of eight 0/1 ints, least significant bit first.
    :raises TypeError: If ``o`` is not an int.
    :raises ValueError: If ``o`` is outside 0..255.
    """
    if not isinstance(o, int):
        raise TypeError("o should be an int")
    if o < 0 or o > 255:
        raise ValueError("o should be between 0 and 255 inclusive")
    # Shift each bit position down to the lowest bit and mask it off.
    return [(o >> position) & 1 for position in range(8)]