"""Package teneva, module cross: construct TT-tensor, using TT-cross.
This module contains the function "cross" which computes the TT-approximation
for implicit tensor given functionally by the rank-adaptive multidimensional
cross approximation method in the TT-format (TT-cross).
"""
import numpy as np
import teneva
from time import perf_counter as tpc
def cross(f, Y0, m=None, e=None, nswp=None, tau=1.1, dr_min=1, dr_max=1,
          tau0=1.05, k0=100, info=None, cache=None, I_vld=None, y_vld=None,
          e_vld=None, cb=None, func=None, m_cache_scale=5, log=False):
    """Compute the TT-approximation for implicit tensor given functionally.

    This function computes the TT-approximation for implicit tensor given
    functionally by the rank-adaptive multidimensional cross approximation
    method in the TT-format (TT-cross).

    Args:
        f (function): function f(I) which computes tensor elements for the
            given set of multi-indices I, where I is a 2D np.ndarray of the
            shape [samples, dimensions]. The function should return 1D
            np.ndarray of the length equals to samples, which relates to the
            values of the target function for all provided samples. If the
            function returns None, then the algorithm will be interrupted and
            the current result will be returned (in the info dictionary, in
            this case, the stop type of the algorithm will be "func").
        Y0 (list): TT-tensor, which is the initial approximation for algorithm.
        m (int): optional limit on the maximum number of requests to the
            objective function (> 0). If specified, then the total number of
            requests will not exceed this value. Note that the actual number of
            requests may be less, since the values are requested in batches.
        e (float): optional algorithm convergence criterion (> 0). If between
            iterations the relative rate of solution change is less than this
            value, then the operation of the algorithm will be interrupted.
        nswp (int): optional maximum number of iterations (sweeps) of the
            algorithm (>= 0). One sweep corresponds to a complete pass of all
            tensor TT-cores from left to right and then from right to left. If
            nswp = 0, then only "maxvol-preiteration" will be performed.
        tau (float): accuracy parameter (>= 1) for the algorithm "maxvol_rect"
            (see "maxvol_rect" function for more details).
        dr_min (int): minimum number of added rows in the process of adaptively
            increasing the TT-rank of the approximation using the algorithm
            "maxvol_rect" (see "maxvol_rect" function for more details). Note
            that dr_min should be no bigger than dr_max.
        dr_max (int): maximum number of added rows in the process of adaptively
            increasing the TT-rank of the approximation using the algorithm
            "maxvol_rect" (see "maxvol_rect" function for more details). Note
            that dr_max should be no less than dr_min. If dr_max = 0,
            then basic maxvol algorithm will be used (rank will be constant).
        tau0 (float): accuracy parameter (>= 1) for the algorithm "maxvol" (see
            "maxvol" function for more details). It will be used while maxvol
            preiterations and while the calls of "maxvol" function from the
            "maxvol_rect" algorithm.
        k0 (int): maximum number of maxvol iterations (>= 1; see "maxvol"
            function for more details). It will be used while maxvol
            preiterations and while the calls of "maxvol" function from the
            "maxvol_rect" algorithm.
        info (dict): an optionally set dictionary, which will be filled with
            reference information about the process of the algorithm operation.
            If it is not passed (None), a fresh dictionary is created for this
            call only. At the end of the function work, it will contain
            parameters: m - total number of requests to the target function;
            e - the final value of the convergence criterion; e_vld - the
            final error on the validation dataset; nswp - the real number of
            performed iterations (sweeps); m_cache - total number of requests
            to the cache; stop - stop type of the algorithm (see note below).
            The keys r (effective rank of the current approximation), m_max
            (the limit m converted to int, or None) and with_cache (whether a
            cache dictionary was passed) are written as well.
        cache (dict): an optionally set dictionary, which will be filled with
            requested function values. Since the algorithm sometimes requests
            the same tensor indices, the use of such a cache may speed up the
            operation of the algorithm if the time to find a value in the cache
            is less than the time to calculate the target function.
        I_vld (np.ndarray): optional multi-indices for items of validation
            dataset in the form of array of the shape [samples_vld, d], where
            samples_vld is a size of the validation dataset.
        y_vld (np.ndarray): optional values of the tensor for multi-indices
            I_vld of validation dataset in the form of array of the shape
            [samples_vld].
        e_vld (float): optional algorithm convergence criterion (> 0). If
            after sweep, the error on the validation dataset is less than this
            value, then the operation of the algorithm will be interrupted.
        cb (function): optional callback function. It will be called after
            every sweep and the accuracy check with the arguments: Y,
            info and opts, where Y is the current approximation (TT-tensor),
            info is the info dictionary and the dictionary opts contains fields
            Ir, Ic, cache and Yold. If the callback returns True, then the
            algorithm will be stopped (in the info dictionary, in this case,
            the stop type of the algorithm will be "cb", unless another stop
            type was already set). Note that the result is compared with
            "is True", so other truthy values (e.g., numpy booleans) do not
            stop the algorithm.
        func (function): if this function is set, then it will replace the inner
            function _func, which deals with requests to the objective
            function f. This argument is used only for internal experiments.
        m_cache_scale (int): if the number of requests to the cache is
            m_cache_scale times greater than the current number of requests to
            the target function, then the algorithm will stop (in the info
            dictionary, in this case, the stop type will be "conv").
        log (bool): if flag is set, then the information about the progress of
            the algorithm will be printed after each sweep.

    Returns:
        list: TT-Tensor which approximates the implicit tensor.

    Raises:
        ValueError: if none of the stop criteria m / e / nswp / e_vld is
            usable, or if e_vld is set without the validation dataset.

    Note:
        Note that at least one of the arguments m / e / nswp / e_vld should be
        set by user. The end of the algorithm operation occurs when one of the
        following criteria is reached: 1) the target function returns None
        instead of a batch of values; 2) the maximum allowable number of the
        objective function calls (m) has been done (more precisely, if the next
        request will result in exceeding this value, then algorithm will not
        perform this new request); 3) the convergence criterion (e) is reached;
        4) the maximum number of iterations (nswp) is performed; 5) the
        algorithm is already converged (all requested values are in the cache
        already); 6) the error on validation dataset I_vld, y_vld is less than
        e_vld; 7) the callback function returns true value. The related stop
        type (func, m, e, nswp, conv, e_vld or cb) will be written into the
        item stop of the info dictionary.

        The resulting TT-tensor usually has overestimated ranks, so you should
        truncate the result. Use for this Y = teneva.truncate(Y, e) (e.g.,
        e = 1.E-8) after this function call.

    """
    # A fresh dict per call: a "{}" default would be shared (and mutated by
    # info.update below) across every call that does not pass "info".
    if info is None:
        info = {}

    if m is None and e is None and nswp is None:
        if I_vld is None or y_vld is None:
            raise ValueError('One of arguments m/e/nswp should be set')
        elif e_vld is None:
            raise ValueError('One of arguments m/e/e_vld/nswp should be set')
    if e_vld is not None and (I_vld is None or y_vld is None):
        raise ValueError('Validation dataset is not set')

    _time = tpc()
    info.update({'r': teneva.erank(Y0), 'e': -1, 'e_vld': -1, 'nswp': 0,
        'stop': None, 'm': 0, 'm_cache': 0, 'm_max': int(m) if m else None,
        'with_cache': cache is not None})

    d = len(Y0)
    n = teneva.shape(Y0)
    Y = teneva.copy(Y0)

    # Ig[i] lists all indices of the i-th mode as a column; Ir[i] / Ic[i] hold
    # the multi-index sets selected to the left / right of the i-th core
    # (None at the tensor boundaries).
    Ig = [teneva._reshape(np.arange(k, dtype=int), (-1, 1)) for k in n]
    Ir = [None for i in range(d+1)]
    Ic = [None for i in range(d+1)]

    # Maxvol preiteration: build the initial index sets from the cores of the
    # initial approximation, without any request to the target function.
    R = np.ones((1, 1))
    for i in range(d):
        G = np.tensordot(R, Y[i], 1)
        Y[i], R, Ir[i+1] = _iter(G, Ig[i], Ir[i], tau0=tau0, k0=k0, ltr=True)
    Y[d-1] = np.tensordot(Y[d-1], R, 1)

    R = np.ones((1, 1))
    for i in range(d-1, -1, -1):
        G = np.tensordot(Y[i], R, 1)
        Y[i], R, Ic[i] = _iter(G, Ig[i], Ic[i+1], tau0=tau0, k0=k0, ltr=False)
    Y[0] = np.tensordot(R, Y[0], 1)

    info['e_vld'] = teneva.accuracy_on_data(Y, I_vld, y_vld)
    teneva._info_appr(info, _time, nswp, e, e_vld, log)

    # The request handler may be replaced by the user (internal experiments).
    request = func or _func

    while True:
        Yold = teneva.copy(Y)

        # Left-to-right half-sweep.
        R = np.ones((1, 1))
        for i in range(d):
            Z = request(f, Ig[i], Ir[i], Ic[i+1], info, cache)
            if info['stop']:
                # The request was refused (budget exhausted or f returned
                # None): absorb the pending factor R into the current core so
                # that Y is a consistent TT-tensor, then return it.
                Y[i] = np.tensordot(R, Y[i], 1)
                info['r'] = teneva.erank(Y)
                info['e'] = teneva.accuracy(Y, Yold)
                info['e_vld'] = teneva.accuracy_on_data(Y, I_vld, y_vld)
                teneva._info_appr(info, _time, nswp, e, e_vld, log)
                return Y
            Y[i], R, Ir[i+1] = _iter(Z, Ig[i], Ir[i],
                tau, dr_min, dr_max, tau0, k0, ltr=True)
        Y[d-1] = np.tensordot(Y[d-1], R, 1)

        # Right-to-left half-sweep.
        R = np.ones((1, 1))
        for i in range(d-1, -1, -1):
            Z = request(f, Ig[i], Ir[i], Ic[i+1], info, cache)
            if info['stop']:
                Y[i] = np.tensordot(Y[i], R, 1)
                info['r'] = teneva.erank(Y)
                info['e'] = teneva.accuracy(Y, Yold)
                info['e_vld'] = teneva.accuracy_on_data(Y, I_vld, y_vld)
                teneva._info_appr(info, _time, nswp, e, e_vld, log)
                return Y
            Y[i], R, Ic[i] = _iter(Z, Ig[i], Ic[i+1],
                tau, dr_min, dr_max, tau0, k0, ltr=False)
        Y[0] = np.tensordot(R, Y[0], 1)

        info['nswp'] += 1
        info['r'] = teneva.erank(Y)
        info['e'] = teneva.accuracy(Y, Yold)
        info['e_vld'] = teneva.accuracy_on_data(Y, I_vld, y_vld)

        # Mostly cache hits and very few new requests: treat as converged.
        if info['m_cache'] > m_cache_scale * info['m']:
            info['stop'] = 'conv'

        if cb:
            opts = {'Yold': Yold, 'Ir': Ir, 'Ic': Ic, 'cache': cache}
            if cb(Y, info, opts) is True:
                # Keep an earlier stop reason (e.g. "conv") if there is one.
                info['stop'] = info['stop'] or 'cb'

        # A true return value is used here as the signal to stop and return.
        if teneva._info_appr(info, _time, nswp, e, e_vld, log):
            return Y
def _func(f, Ig, Ir, Ic, info, cache=None):
    """Request the values of the target function for one TT-core.

    The multi-indices are assembled column-wise as [Ir | Ig | Ic] (the Ir / Ic
    parts are skipped when the corresponding set is None) by Kronecker
    products, so every combination of a left index, a mode index and a right
    index is present. The values are obtained through _func_eval.

    Args:
        f (function): target function (see "cross").
        Ig (np.ndarray): indices of the current mode; its first dimension is
            used as the mode size n.
        Ir (np.ndarray): left multi-indices (r1 rows) or None (then r1 = 1).
        Ic (np.ndarray): right multi-indices (r2 rows) or None (then r2 = 1).
        info (dict): info dictionary, which is passed to _func_eval.
        cache (dict): optional cache, which is passed to _func_eval.

    Returns:
        The requested values reshaped by teneva._reshape to (r1, n, r2), or
        None if _func_eval returned None (the request was not performed).

    """
    # NOTE(review): teneva._ones(k) is assumed to return a column of k ones,
    # so that the Kronecker products below only repeat rows -- confirm.
    size_mode = Ig.shape[0]
    rank_left = 1 if Ir is None else Ir.shape[0]
    rank_right = 1 if Ic is None else Ic.shape[0]

    # Mode indices, repeated for every (left, right) combination.
    mode_part = np.kron(teneva._ones(rank_right), Ig)
    I = np.kron(mode_part, teneva._ones(rank_left))

    if Ir is not None:
        left_part = np.kron(teneva._ones(size_mode * rank_right), Ir)
        I = np.hstack((left_part, I))

    if Ic is not None:
        right_part = np.kron(Ic, teneva._ones(rank_left * size_mode))
        I = np.hstack((I, right_part))

    values = _func_eval(f, I, info, cache)
    if values is None:
        return None
    return teneva._reshape(values, (rank_left, size_mode, rank_right))
def _func_eval(f, I, info, cache=None):
    """Evaluate the target function for multi-indices, using optional cache.

    Args:
        f (function): target function, which is called for a batch (2D array)
            of multi-indices; it may return None to interrupt the algorithm.
        I (np.ndarray): requested multi-indices (one per row).
        info (dict): info dictionary. The item "m_max" is read; the items "m",
            "m_cache" and "stop" are updated.
        cache (dict): optional dictionary, which maps tuple(multi-index) to
            the float value. New values are stored in it.

    Returns:
        np.ndarray: float values for all rows of I, or None if the request was
        not performed: either the limit "m_max" would be exceeded (then
        info["stop"] is "m") or f returned None (then info["stop"] is "func").

    """
    budget = info['m_max']

    if cache is None:
        count = len(I)
        if budget is not None and info['m'] + count > budget:
            info['stop'] = 'm'
            return None

        y = f(I)
        if y is None:
            info['stop'] = 'func'
            return None

        info['m'] += count
        return np.array(y, dtype=float)

    # Hashable keys for all requested rows (computed only once).
    keys = [tuple(row) for row in I]

    # Only the rows, which are absent in the cache, are sent to f.
    I_new = np.array([row for row, key in zip(I, keys) if key not in cache])
    count_new = len(I_new)

    if count_new:
        if budget is not None and info['m'] + count_new > budget:
            info['stop'] = 'm'
            return None

        y_new = f(I_new)
        if y_new is None:
            info['stop'] = 'func'
            return None

        for pos, row in enumerate(I_new):
            cache[tuple(row)] = float(y_new[pos])

    info['m'] += count_new
    info['m_cache'] += len(I) - count_new

    return np.array([cache[key] for key in keys], dtype=float)
def _iter(Z, Ig, I, tau=1.1, dr_min=0, dr_max=0, tau0=1.05, k0=100, ltr=True):
    """Update one TT-core and the related index set by QR and maxvol.

    The 3D array Z is unfolded into a matrix (rows are "left rank x mode" for
    the left-to-right pass, or "mode x right rank" for the right-to-left
    pass), decomposed by QR, and then teneva._maxvol is applied to the
    Q-factor to select the rows.

    Args:
        Z (np.ndarray): 3D array of the shape [r1, n, r2].
        Ig (np.ndarray): indices of the current mode.
        I (np.ndarray): multi-indices selected at the previous step of the
            pass (left ones if ltr is set, right ones otherwise) or None.
        tau, dr_min, dr_max, tau0, k0: parameters, which are passed to
            teneva._maxvol without changes.
        ltr (bool): direction of the pass (True means left-to-right).

    Returns:
        tuple: the new TT-core G (3D array), the matrix R which should be
        absorbed into the next core of the pass, and the multi-indices
        related to the selected rows.

    """
    # NOTE(review): teneva._maxvol is presumably returning the indices of the
    # selected rows and the coefficient matrix B; teneva._ones(k) is assumed
    # to be a column of k ones -- confirm against the teneva sources.
    r1, n, r2 = Z.shape

    if ltr:
        unfolding = teneva._reshape(Z, (r1 * n, r2))
    else:
        unfolding = teneva._reshape(Z, (r1, n * r2)).T

    Q, R = np.linalg.qr(unfolding)
    ind, B = teneva._maxvol(Q, tau, dr_min, dr_max, tau0, k0)

    # The factor restricted to the selected rows (transposed below for the
    # right-to-left pass).
    R_sel = Q[ind, :] @ R

    if not ltr:
        G = teneva._reshape(B.T, (-1, n, r2))
        candidates = np.kron(teneva._ones(r2), Ig)
        if I is not None:
            # The previous (right) multi-indices are appended as last columns.
            previous = np.kron(I, teneva._ones(n))
            candidates = np.hstack((candidates, previous))
        return G, R_sel.T, candidates[ind, :]

    G = teneva._reshape(B, (r1, n, -1))
    candidates = np.kron(Ig, teneva._ones(r1))
    if I is not None:
        # The previous (left) multi-indices are prepended as first columns.
        previous = np.kron(teneva._ones(n), I)
        candidates = np.hstack((previous, candidates))
    return G, R_sel, candidates[ind, :]