# Source code for teneva.als

"""Package teneva, module als: construct TT-tensor, using TT-ALS.

This module contains the function "als" which computes the TT-approximation for
the tensor by TT-ALS algorithm, using given random samples (i.e., the set of
random tensor multi-indices and related tensor values).

"""
import numpy as np
from opt_einsum import contract
import scipy as sp
import teneva
from time import perf_counter as tpc


def als(I_trn, y_trn, Y0, nswp=50, e=1.E-16, info=None, *, I_vld=None,
        y_vld=None, e_vld=None, r=None, r_add=10000, e_adap=1.E-3,
        lamb=0.001, w=None, cb=None, swap_tol=3, allow_swap=False,
        allow_skip_cores=False, use_stab=False, log=False, update_sol=None):
    """Build TT-tensor by TT-ALS method using given random tensor samples.

    Args:
        I_trn (np.ndarray): multi-indices for the tensor in the form of array
            of the shape [samples, d], where d is a number of tensor's
            dimensions and samples is a size of the train dataset.
        y_trn (np.ndarray): values of the tensor for multi-indices I_trn in
            the form of array of the shape [samples].
        Y0 (list): TT-tensor, which is the initial approximation for
            algorithm.
        nswp (int): number of ALS iterations (sweeps). If e or e_vld parameter
            is set, then the real number of sweeps may be less (see info dict
            with the exact number of performed sweeps).
        e (float): optional algorithm convergence criterion (> 0). If between
            iterations (sweeps) the relative rate of solution change is less
            than this value, then the operation of the algorithm will be
            interrupted.
        info (dict): an optionally set dictionary, which will be filled with
            reference information about the process of the algorithm
            operation. At the end of the function work, it will contain
            parameters: e - the final value of the convergence criterion;
            e_vld - the final error on the validation dataset; nswp - the
            real number of performed iterations (sweeps); stop - stop type of
            the algorithm (nswp, e, e_vld or cb). The keys r (effective rank
            of the current approximation) and, if allow_swap is set,
            rearrange (array tracking the performed mode swaps) are written
            as well. If None, a fresh dictionary is used internally.
        I_vld (np.ndarray): optional multi-indices for items of validation
            dataset in the form of array of the shape [samples_vld, d], where
            samples_vld is a size of the validation dataset.
        y_vld (np.ndarray): optional values of the tensor for multi-indices
            I_vld of validation dataset in the form of array of the shape
            [samples_vld].
        e_vld (float): optional algorithm convergence criterion. If after
            sweep, the error on the validation dataset is less than this
            value, then the operation of the algorithm will be interrupted.
        r (int): maximum TT-rank for rank-adaptive ALS algorithm. If is None,
            then the TT-ALS with constant rank will be used (in the case of
            the constant rank, its value will be the same as the rank of the
            initial approximation Y0).
        r_add (int): maximum rank grow on one iteration for the rank-adaptive
            ALS algorithm. It is used only if r argument is not None.
        e_adap (float): convergence criterion for rank-adaptive TT-ALS
            algorithm (> 0). It is used only if r argument is not None.
        lamb (float): regularization parameter for least squares.
        w (np.ndarray): optional vector for weights of the input data (it
            should have a length equal to the number of elements in the data
            set). If this vector is used, then lamb parameter should be None.
        cb (function): optional callback function. It will be called after
            each sweep and the accuracy check with the arguments: Y, info and
            opts, where Y is the current approximation (TT-tensor), info is
            the info dictionary and the dictionary opts contains fields Yl,
            Yr and Yold. If the callback returns True, then the algorithm
            will be stopped (in the info dictionary, in this case, the stop
            type of the algorithm will be cb).
        swap_tol (int): experimental option.
        allow_swap (bool): experimental flag. It requires r to be set.
        allow_skip_cores (bool): if there is no data to learn all slices of
            the TT-cores, the algorithm still works, keeping these slices. If
            the flag is not enabled, then in case of insufficient size of the
            training dataset, an error will be generated.
        use_stab (bool): if the flag is set, then the rank-adaptive method
            will use additional stabilization of the cores.
        log (bool): if flag is set, then the information about the progress
            of the algorithm will be printed after each sweep (and before the
            first sweep).
        update_sol: if it is not None (allowed only for the constant rank
            mode, i.e., r is None), then every core slice is updated
            incrementally: the least squares problem is solved for the
            residual with respect to the current slice, and the solution is
            added to that slice instead of replacing it.

    Returns:
        list: TT-tensor, which represents the TT-approximation for the tensor.

    Raises:
        AssertionError: if both r and update_sol are set, or if allow_swap is
            set while r is None.
        ValueError: if allow_skip_cores is not set and some mode index never
            occurs in the train dataset.

    """
    _time = tpc()

    # A dict default would be shared between all calls that omit "info":
    if info is None:
        info = {}

    assert r is None or update_sol is None, "Cannot update core of non-constant rank"

    info.update({'e': -1, 'e_vld': -1, 'nswp': 0, 'stop': None})
    info['r'] = teneva.erank(Y0)

    I_trn = np.asanyarray(I_trn, dtype=int)
    y_trn = np.asanyarray(y_trn, dtype=float)

    m = I_trn.shape[0]
    d = I_trn.shape[1]

    if allow_swap:
        msg = 'The option "allow_swap" works only with adaptive rank'
        assert r is not None, msg
        # Columns of I_trn are swapped in place below, so work on a copy:
        I_trn = I_trn.copy()
        rearrange = np.arange(d)
        info['rearrange'] = rearrange
        print('!!! Note that "allow_swap" is a VERY experimental option')

    Y = teneva.copy(Y0)
    if r is not None:
        Y = teneva.orthogonalize(Y, 0, use_stab)

    if not allow_skip_cores:
        for k in range(d):
            if np.unique(I_trn[:, k]).size != Y[k].shape[1]:
                msg = 'One groundtruth sample is needed for every slice'
                raise ValueError(msg)

    info['e_vld'] = teneva.accuracy_on_data(Y, I_vld, y_vld)
    teneva._info_appr(info, _time, nswp, e, e_vld, log)

    # Per-sample left and right interface matrices for every core:
    Yl = [np.ones((m, Y[k].shape[0])) for k in range(d)]
    Yr = [np.ones((Y[k].shape[2], m)) for k in range(d)]

    for k in range(d-1, 0, -1):
        i = I_trn[:, k]
        Q = Y[k][:, i, :]
        contract('riq,qi->ri', Q, Yr[k], out=Yr[k-1])

    while True:
        Yold = teneva.copy(Y)
        was_swap = False

        # Sweep from left to right:
        idx_cache = dict()
        for k in range(0, d-1 if r is None else d-2, +1):
            i = I_trn[:, k]
            if r is None:
                Y[k] = _optimize_core(
                    Y[k], i, y_trn, Yl[k], Yr[k],
                    lamb=lamb, w=w, update_sol=update_sol)
                contract('jk,kjl->jl', Yl[k], Y[k][:, i, :], out=Yl[k+1])
            else:
                swaped = {} if allow_swap else None
                r_max = min(r, Y[k].shape[-1] + r_add)
                Y[k], Y[k+1] = _optimize_core_adaptive(
                    Y[k], Y[k+1], i, I_trn[:, k+1], y_trn, Yl[k], Yr[k+1],
                    e_adap, r_max, lamb, w, ltr=True,
                    allow_swap=swaped, swap_tol=swap_tol, cache=idx_cache)
                # Masks of mode k+1 become the "first mode" masks next step:
                idx_cache = dict(i1=idx_cache['i2'])
                if allow_swap and swaped.get('swapped', False):
                    print(f'DEBUG | idxs: {k} <-> {k+1}')
                    was_swap = True
                    I_trn[:, [k, k+1]] = I_trn[:, [k+1, k]]
                    i = I_trn[:, k]
                    swap_two = np.arange(len(rearrange))
                    swap_two[k], swap_two[k+1] = swap_two[k+1], swap_two[k]
                    rearrange[:] = swap_two[rearrange]
                    # The cached masks refer to the old order of the modes:
                    idx_cache = {}
                Yl[k+1] = contract('jk,kjl->jl', Yl[k], Y[k][:, i, :])

        # Sweep from right to left:
        idx_cache = dict()
        for k in range(d-1, 0 if r is None else 1, -1):
            i = I_trn[:, k]
            if r is None:
                Y[k] = _optimize_core(
                    Y[k], i, y_trn, Yl[k], Yr[k],
                    lamb=lamb, w=w, update_sol=update_sol)
                contract('ijk,kj->ij', Y[k][:, i, :], Yr[k], out=Yr[k-1])
            else:
                swaped = {} if allow_swap else None
                r_max = min(r, Y[k-1].shape[-1] + r_add)
                Y[k-1], Y[k] = _optimize_core_adaptive(
                    Y[k-1], Y[k], I_trn[:, k-1], i, y_trn, Yl[k-1], Yr[k],
                    e_adap, r_max, lamb, w, ltr=False,
                    allow_swap=swaped, swap_tol=swap_tol, cache=idx_cache)
                # Masks of mode k-1 become the "second mode" masks next step:
                idx_cache = dict(i2=idx_cache['i1'])
                if allow_swap and swaped.get('swapped', False):
                    print(f'DEBUG | idxs: {k} <-> {k-1}')
                    was_swap = True
                    I_trn[:, [k, k-1]] = I_trn[:, [k-1, k]]
                    i = I_trn[:, k]
                    swap_two = np.arange(len(rearrange))
                    swap_two[k], swap_two[k-1] = swap_two[k-1], swap_two[k]
                    rearrange[:] = swap_two[rearrange]
                    # The cached masks refer to the old order of the modes:
                    idx_cache = {}
                Yr[k-1] = contract('ijk,kj->ij', Y[k][:, i, :], Yr[k])

        info['nswp'] += 1
        info['r'] = teneva.erank(Y)
        # After a swap Y and Yold are not comparable, so force one more sweep:
        info['e'] = 1.E+10 if was_swap else teneva.accuracy(Y, Yold)

        # The validation set may be absent (None) even if allow_swap is set:
        if allow_swap and I_vld is not None:
            I_vld_cur = I_vld[:, rearrange]
        else:
            I_vld_cur = I_vld
        info['e_vld'] = teneva.accuracy_on_data(Y, I_vld_cur, y_vld)

        if cb:
            opts = {'Yold': Yold, 'Yl': Yl, 'Yr': Yr}
            if cb(Y, info, opts) is True:
                info['stop'] = info['stop'] or 'cb'

        if teneva._info_appr(info, _time, nswp, e, e_vld, log):
            return Y


def _lstsq(A, y, lamb=1e-2, w=None, *, overwrite_a=True, update_sol=None):
    """Solve the (optionally regularized and weighted) least squares problem.

    Args:
        A (np.ndarray): matrix of the system of the shape [samples, n].
        y (np.ndarray): right hand side of the shape [samples].
        lamb (float): regularization parameter. If it is not None, then the
            normal equations "(A^T A + lamb * I) x = A^T y" are solved (with
            "A^T diag(w) A" and "A^T diag(w) y" if weights are given). If it
            is None, then the system "A x = y" is solved directly (both sides
            are multiplied by w if weights are given).
        w (np.ndarray): optional weights of the shape [samples].
        overwrite_a (bool): it is used only if lamb and w are both None. If
            the flag is not set, then the solver gets a copy of A; otherwise
            the content of A may be destroyed by the solver.
        update_sol (np.ndarray): optional current solution of the shape [n].
            If it is set, then the problem is solved for the residual
            "y - A @ update_sol", i.e., the result is a correction to it.

    Returns:
        tuple: the result of "scipy.linalg.lstsq" (its first item is the
        solution vector).

    Note:
        If lamb, w and update_sol are all None, then y is passed to the
        solver with "overwrite_b=True", so its content may be destroyed.

    """
    if update_sol is not None:
        y = y - A @ update_sol

    if lamb is not None:
        if w is not None:
            AW = w[:, None] * A
            AtA = A.T @ AW
            Aty = AW.T @ y
        else:
            AtA = A.T @ A
            Aty = A.T @ y
        return sp.linalg.lstsq(AtA + lamb * np.identity(A.shape[1]), Aty,
            overwrite_a=True, overwrite_b=True, lapack_driver='gelsy')

    if w is not None:
        A = w[:, None] * A
        y = y * w
    elif not overwrite_a:
        A = np.copy(A)
    return sp.linalg.lstsq(A, y,
        overwrite_a=True, overwrite_b=True, lapack_driver='gelsy')


def _optimize_core(Q, i, y_trn, Yl, Yr, lamb, w, update_sol=None):
    """Update all slices of one TT-core by least squares (constant rank).

    Args:
        Q (np.ndarray): TT-core of the shape [r1, n, r2] (it is not changed).
        i (np.ndarray): indices of the core's mode for all train samples.
        y_trn (np.ndarray): values of the tensor for all train samples.
        Yl (np.ndarray): left interface matrix of the shape [samples, r1].
        Yr (np.ndarray): right interface matrix of the shape [r2, samples].
        lamb (float): regularization parameter (see "_lstsq").
        w (np.ndarray): optional weights of the train samples.
        update_sol: if it is not None, then the least squares problem is
            solved for the residual with respect to the current slice, and
            the solution is added to the slice; otherwise the slice is
            replaced by the solution.

    Returns:
        np.ndarray: the updated copy of the TT-core. The slices, for which
        there are no train samples, are kept as is.

    """
    Q = Q.copy()

    for k in range(Q.shape[1]):
        idx = np.where(i == k)[0]
        # Note that "idx" holds sample numbers, hence "idx.any()" is False
        # for the valid selection [0]; check for emptiness explicitly:
        if idx.size == 0:
            continue

        lhs = Yr[:, idx].T[:, np.newaxis, :]
        rhs = Yl[idx, :][:, :, np.newaxis]
        A = (lhs * rhs).reshape(len(idx), -1)
        b = y_trn[idx]
        w_cur = w[idx] if w is not None else None

        # TODO: check the rank returned by the solver to warn about the bad
        # conditioning of the least squares problem.
        if update_sol is None:
            sol = _lstsq(A, b, lamb=lamb, w=w_cur, update_sol=None)[0]
            Q[:, k, :] = sol.reshape(Q[:, k, :].shape)
        else:
            sol = _lstsq(A, b, lamb=lamb, w=w_cur,
                update_sol=Q[:, k, :].reshape(-1))[0]
            Q[:, k, :] += sol.reshape(Q[:, k, :].shape)

    return Q


def _optimize_core_adaptive(Q1, Q2, i1, i2, y_trn, Yl, Yr, e, r, lamb, w,
                            ltr=True, allow_swap=None, swap_tol=3,
                            cache=None):
    """Update two neighboring TT-cores with adaptive choice of their rank.

    The merged core (super-core) is fitted slice by slice by least squares
    and then it is split into two cores by "teneva.matrix_skeleton".

    Args:
        Q1 (np.ndarray): the first TT-core of the shape [r0, n1, r1].
        Q2 (np.ndarray): the second TT-core of the shape [r1, n2, r2].
        i1 (np.ndarray): indices of the first mode for all train samples.
        i2 (np.ndarray): indices of the second mode for all train samples.
        y_trn (np.ndarray): values of the tensor for all train samples.
        Yl (np.ndarray): left interface matrix of the shape [samples, r0].
        Yr (np.ndarray): right interface matrix of the shape [r2, samples].
        e (float): accuracy passed to "teneva.matrix_skeleton".
        r (int): maximum rank passed to "teneva.matrix_skeleton".
        lamb (float): regularization parameter (see "_lstsq").
        w (np.ndarray): optional weights of the train samples.
        ltr (bool): direction of the sweep; it selects the "give_to" argument
            of "teneva.matrix_skeleton" ("r" if set and "l" otherwise).
        allow_swap (dict): if it is not None, then the decomposition with the
            swapped modes is also tried; if it is selected, then the item
            "swapped" of this dictionary is set to True.
        swap_tol (int): multiplier for the error of the swapped decomposition
            in the comparison with the error of the original one.
        cache (dict): optional dictionary with boolean masks of samples for
            every mode index (keys "i1" and "i2"). The missing items are
            computed and stored in this dictionary in place, hence the caller
            can reuse them.

    Returns:
        tuple: two new TT-cores (for the swapped modes they have the shapes
        [r0, n2, r] and [r, n1, r2]).

    """
    shape = Q1.shape[0], Q2.shape[2]
    shapeQ1 = Q1.shape[:2]
    shapeQ2 = Q2.shape[1:]

    # The pairs of indices without train samples are skipped below, hence
    # the array must be initialized (uninitialized memory would leak into
    # the decomposition otherwise):
    Q = np.zeros((Q1.shape[0], Q1.shape[1], Q2.shape[1], Q2.shape[2]))

    cache = {} if cache is None else cache
    if 'i1' not in cache:
        cache['i1'] = {k1: i1 == k1 for k1 in range(Q1.shape[1])}
    if 'i2' not in cache:
        cache['i2'] = {k2: i2 == k2 for k2 in range(Q2.shape[1])}

    for k1 in range(Q1.shape[1]):
        for k2 in range(Q2.shape[1]):
            idx = cache['i1'][k1] & cache['i2'][k2]
            if not idx.any():
                continue

            lhs = Yr[:, idx].T[:, np.newaxis, :]
            rhs = Yl[idx, :][:, :, np.newaxis]
            A = (lhs * rhs).reshape(idx.sum(), -1)
            b = y_trn[idx]

            # TODO: check the rank returned by the solver to warn about the
            # bad conditioning of the least squares problem.
            sol = _lstsq(A, b, lamb=lamb,
                w=w[idx] if w is not None else None)[0]
            Q[:, k1, k2, :] = sol.reshape(shape)

    give_to = 'r' if ltr else 'l'

    Qs = Q.reshape(np.prod(Q.shape[:2]), -1)
    V1, V2 = teneva.matrix_skeleton(Qs, e, r, rel=True, give_to=give_to)
    rank1 = V1.shape[-1]

    if allow_swap is not None:
        Q = np.transpose(Q, [0, 2, 1, 3])
        Qsr = Q.reshape(np.prod(Q.shape[:2]), -1)
        V1r, V2r = teneva.matrix_skeleton(Qsr, e, r, rel=True,
            give_to=give_to)
        rank2 = V1r.shape[-1]

        qual1 = _quality_of_decomp(Qs, V1, V2)
        qual2 = _quality_of_decomp(Qsr, V1r, V2r) * swap_tol

        if rank2 < rank1 or qual1 > qual2:
            print(f'DEBUG | ranks: {rank2} < {rank1}, swapping', end=' ')
            allow_swap['swapped'] = True
            V1 = V1r
            V2 = V2r
            shapeQ1 = (Q1.shape[0], Q2.shape[1])
            shapeQ2 = (Q1.shape[1], Q2.shape[-1])

    return V1.reshape(*shapeQ1, -1), V2.reshape(-1, *shapeQ2)


def _quality_of_decomp(Q, V1, V2):
    """Return the relative error (in the default norm) of "V1 @ V2" w.r.t. Q."""
    return np.linalg.norm(V1 @ V2 - Q) / np.linalg.norm(Q)