Source code for scqubits.utils.misc

# misc.py
#
# This file is part of scqubits: a Python package for superconducting qubits,
# Quantum 5, 583 (2021). https://quantum-journal.org/papers/q-2021-11-17-583/
#
#    Copyright (c) 2019 and later, Jens Koch and Peter Groszkowski
#    All rights reserved.
#
#    This source code is licensed under the BSD-style license found in the
#    LICENSE file in the root directory of this source tree.
############################################################################

import ast
import functools
import platform
import warnings

import inspect
from collections.abc import Sequence
from distutils.version import StrictVersion
from io import StringIO
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import matplotlib
import numpy as np
import qutip as qt
import scipy as sp
from matplotlib import get_backend as get_matplotlib_backend

import scqubits.settings
from scqubits.settings import IN_IPYTHON

if IN_IPYTHON:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm


[docs]def process_which(which: Union[int, Iterable[int]], max_index: int) -> List[int]: """Processes different ways of specifying the selection of wanted eigenvalues/eigenstates. Parameters ---------- which: single index or tuple/list of integers indexing the eigenobjects. If 'which' is -1, all indices up to the max_index limit are included. max_index: maximum index value Returns ------- indices """ if isinstance(which, int): if which == -1: return list(range(max_index)) return [which] return list(which)
[docs]def make_bare_labels(subsystem_count: int, *args) -> Tuple[int, ...]: """ For two given subsystem states, return the full-system bare state label obtained by placing all remaining subsys_list in their ground states. Parameters ---------- subsystem_count: number of subsys_list inside Hilbert space *args: each argument is a tuple of the form (subsys_index, label) Returns ------- Suppose there are 5 subsys_list in total. Let (subsys_index1=0, label1=3), (subsys_index2=2, label2=1). Then the returned bare-state tuple is: (3,0,1,0,0) """ bare_labels = [0] * subsystem_count for subsys_index, label in args: bare_labels[subsys_index] = label return tuple(bare_labels)
[docs]def drop_private_keys(full_dict: Dict[str, Any]) -> Dict[str, Any]: """Filter for entries in the full dictionary that have numerical values""" return {key: value for key, value in full_dict.items() if key[0] != "_"}
[docs]class InfoBar: """Static "progress" bar used whenever multiprocessing is involved. Parameters ---------- desc: Description text to be displayed on the static information bar. num_cpus: Number of CPUS/cores employed in underlying calculation. """ def __init__(self, desc: str, num_cpus: int) -> None: self.desc = desc self.num_cpus = num_cpus self.tqdm_bar = None def __enter__(self) -> None: self.tqdm_bar = tqdm( total=0, disable=(self.num_cpus == 1) or scqubits.settings.PROGRESSBAR_DISABLED, leave=False, desc=self.desc, bar_format="{desc}", ) def __exit__(self, *args) -> None: if self.tqdm_bar: self.tqdm_bar.close()
[docs]class Required: """Decorator class, ensuring that a given requirement or set of requirements is fulfilled. Parameters ---------- dict {str: bool} All bool conditions have to be True to pass. The provided str keys are used to display information on what condition is failing. """ def __init__(self, **requirements) -> None: self.requirements_bools = list(requirements.values()) self.requirements_names = list(requirements.keys()) self.missing_imports = [name for name in requirements if not requirements[name]] def __call__(self, func: Callable, *args, **kwargs) -> Callable: @functools.wraps(func) def decorated_func(*args, **kwargs): if all(self.requirements_bools): return func(*args, **kwargs) else: with warnings.catch_warnings(): if self.missing_imports == ["ipyvuetify"]: warnings.warn( "Starting with v3.2, scqubits uses the optional package 'ipyuetify' for graphical " "user interfaces. To use this functionality, add the package via " "`conda install -c conda-forge ipyvuetify` or `pip install ipyvuetify`.\n" "For use with jupyter lab, additionally execute " "`jupyter labextension install jupyter-vuetify`.\n", category=Warning, ) else: warnings.warn( "use of this method requires the optional package(s):" " {}. If you wish to use this functionality, the corresponding" " package(s) must be installed manually. (Installation via `conda" " install -c conda-forge <packagename>` or `pip install" " <packagename>` is recommended.)".format( self.requirements_names ), category=Warning, ) return decorated_func
[docs]def check_sync_status(func: Callable) -> Callable: @functools.wraps(func) def wrapper(self, *args, **kwargs): if self._out_of_sync and not self._out_of_sync_warning_issued: with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( "[scqubits] Some quantum system parameters have been changed and" " generated spectrum data could be outdated, potentially leading to" " incorrect results. Spectral data can be refreshed via" " <HilbertSpace>.generate_lookup() or <ParameterSweep>.run()", Warning, ) self._out_of_sync_warning_issued = True return func(self, *args, **kwargs) return wrapper
[docs]def check_sync_status_circuit(func: Callable) -> Callable: @functools.wraps(func) def wrapper(self, *args, **kwargs): # update the circuit if necessary if (self._user_changed_parameter) or ( self.hierarchical_diagonalization and (self._out_of_sync or len(self.affected_subsystem_indices) > 0) ): self.update() return func(self, *args, **kwargs) return wrapper
[docs]def check_lookup_exists(func: Callable) -> Callable: @functools.wraps(func) def wrapper(self, *args, **kwargs): if not self._lookup_exists: raise Exception( "Lookup data not found. For HilbertSpace: data must be generated with .generate_lookup(). " "For ParameterSweep: data should be automatically generated unless disabled manually. In " "the latter case, apply .run()." ) return func(self, *args, **kwargs) return wrapper
[docs]class DeprecationMessage: """Decorator class, producing an adjustable warning and info upon usage of the decorated function. Parameters ---------- warning_message: Warnings message to be sent upon decorated (deprecated) routing """ def __init__(self, warning_msg: str) -> None: self.warning_msg = warning_msg def __call__(self, func: Callable, *args, **kwargs) -> Callable: @functools.wraps(func) def decorated_func(*args, **kwargs): warnings.warn(self.warning_msg, FutureWarning) return func(*args, **kwargs) return decorated_func
[docs]def to_expression_or_string(string_expr: str) -> Any: try: return ast.literal_eval(string_expr) except ValueError: return string_expr
[docs]def remove_nones(dict_data: Dict[str, Any]) -> Dict[str, Any]: return {key: value for key, value in dict_data.items() if value is not None}
[docs]def qt_ket_to_ndarray(qobj_ket: qt.Qobj) -> np.ndarray: # Qutip's `.eigenstates()` returns an object-valued ndarray, each idx_entry of which # is a Qobj ket. return ( qobj_ket.data.as_ndarray() if qt.__version__ >= "5.0.0" else qobj_ket.data.toarray() )
[docs]def Qobj_to_scipy_csc_matrix(qobj_array: qt.Qobj) -> sp.sparse.csc_matrix: return ( qobj_array.to("csr").data.as_scipy().tocsc() if qt.__version__ >= "5.0.0" else qobj_array.data.tocsc() )
[docs]def get_shape(lst, shape=()): """ returns the shape of nested lists similarly to numpy's shape. :param lst: the nested list :param shape: the shape up to the current recursion depth :return: the shape including the current depth (finally this will be the full depth) """ if not isinstance(lst, Sequence): # base case return shape # peek ahead and assure all lists in the next depth # have the same length if isinstance(lst[0], Sequence): l = len(lst[0]) if not all(len(item) == l for item in lst): msg = "not all lists have the same length" raise ValueError(msg) shape += (len(lst),) # recurse shape = get_shape(lst[0], shape) return shape
[docs]def tuple_to_short_str(the_tuple: tuple) -> str: short_str = "" for entry in the_tuple: short_str += str(entry) + "," return short_str[:-1]
[docs]def to_list(obj: Any) -> List[Any]: """ Converts object to a list: if the input is already a list, it will return the same list. If the input is a numpy array, it will convert to a python list. Otherwise, it will return the original object in a single-elemented python list. Parameters ---------- obj: Specify the type of object that is being passed into the function Returns ------- a list of the object passed in """ if isinstance(obj, list): return obj if isinstance(obj, np.ndarray): return obj.tolist() return [obj]
[docs]def about(print_info=True) -> Optional[str]: """Prints or returns a string with basic information about scqubits as well as installed version of various packages that scqubits depends on. Parameters ---------- print_info: bool Flag that determines if string with information should be printed (if True) or returned (if False). Returns ------- None or str """ from scqubits import __version__ fs = StringIO() fs.write("scqubits: a Python library for simulating superconducting qubits\n") fs.write("****************************************************************\n") fs.write("Developed by J. Koch, P. Groszkowski\n") fs.write("Main Github page: https://github.com/scqubits/scqubits\n") fs.write( "Online documentation page: https://scqubits.readthedocs.io/en/latest/\n\n" ) fs.write("scqubits version: {}\n".format(__version__)) fs.write("numpy version: {}\n".format(np.__version__)) fs.write("scipy version: {}\n".format(sp.__version__)) fs.write("QuTiP version: {}\n".format(qt.__version__)) fs.write( "Platform: {} ({})\n".format(platform.system(), platform.machine()) ) if print_info: print(fs.getvalue()) return None else: return fs.getvalue()
[docs]def cite(print_info=True): """Prints or returns a string with scqubits citation information. Parameters ---------- print_info: bool Flag that determines if string with information should be printed (if True) or returned (if False). Returns ------- None or str """ fs = StringIO() fs.write("Peter Groszkowski and Jens Koch,\n") fs.write("'scqubits: a Python package for superconducting qubits'\n") fs.write("Quantum 5, 583 (2021).\n") fs.write("https://quantum-journal.org/papers/q-2021-11-17-583/\n") if print_info: print(fs.getvalue()) return None else: return fs.getvalue()
[docs]def is_string_float(the_string: str) -> bool: try: float(the_string) return True except ValueError: return False
[docs]def is_string_int(the_string: str) -> bool: try: int(the_string) return True except ValueError: return False
[docs]def list_intersection(list1: list, list2: list) -> list: return [item for item in list1 if item in list2]
[docs]def flatten_list(nested_list): """ Flattens a list of lists once, not recursive. Parameters ---------- nested_list: A list of lists, which can hold any class instance. Returns ------- Flattened list of objects """ return functools.reduce(lambda a, b: a + b, nested_list)
[docs]def flatten_list_recursive(some_list: list) -> list: """ Flattens a list of lists recursively. Parameters ---------- some_list: A list of lists, which can hold any class instance. Returns ------- Flattened list of objects """ if some_list == []: return some_list if isinstance(some_list[0], list): return flatten_list_recursive(some_list[0]) + flatten_list_recursive( some_list[1:] ) return some_list[:1] + flatten_list_recursive(some_list[1:])
[docs]def unique_elements_in_list(list_object: list) -> list: """ Returns a list of all the unique elements in the list Parameters ---------- list_object : A list of any objects """ unique_list = [] [ unique_list.append(element) for element in list_object if element not in unique_list ] return unique_list
[docs]def number_of_lists_in_list(list_object: list) -> int: """ Takes a list as an argument and returns the number of lists in that list. (Counts lists at root level only, no recursion.) Parameters ---------- list_object: List to be analyzed Returns ------- The number of lists in the list """ return sum([1 for element in list_object if type(element) == list])
[docs]def check_matplotlib_compatibility(): if _HAS_WIDGET_BACKEND and StrictVersion(matplotlib.__version__) < StrictVersion( "3.5.1" ): warnings.warn( "The widget backend requires Matplotlib >=3.5.1 for proper functioning", UserWarning, )
[docs]def inspect_public_API( module: Any, public_names: List[str] = [], private_names: List[str] = [], ) -> List[str]: """ Find all public names in a module. Parameters ---------- module: Module to be inspected public_names: Names that have already been found / manually be set to public private_names: Names that should be excluded from the public API """ for name, obj in inspect.getmembers(module): if name.startswith("_") or name in public_names or name in private_names: continue if inspect.isclass(obj) or inspect.isfunction(obj) or inspect.ismodule(obj): public_names.append(name) elif not callable(obj) and name.isupper(): # constants public_names.append(name) return public_names
MATPLOTLIB_WIDGET_BACKEND = "module://ipympl.backend_nbagg" _HAS_WIDGET_BACKEND = get_matplotlib_backend() == MATPLOTLIB_WIDGET_BACKEND