torch.lu_unpack
torch.lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True)
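Unpacks the data and pivots from an LU factorization of a tensor.
Returns a tuple of tensors as (the pivots, the L tensor, the U tensor). The pivots are returned as a permutation matrix P, so the factorized tensor can be recovered as the product of P, L, and U.

Parameters:
- LU_data (Tensor): the packed LU factorization data
- LU_pivots (Tensor): the packed LU factorization pivots
- unpack_data (bool): flag indicating whether the data should be unpacked (default: True)
- unpack_pivots (bool): flag indicating whether the pivots should be unpacked (default: True)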
Examples:
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))

>>> # LU factorization of a rectangular matrix:
>>> A = torch.randn(2, 3, 2)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>> P
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]]])
>>> A_L
tensor([[[ 1.0000,  0.0000],
         [ 0.4763,  1.0000],
         [ 0.3683,  0.1135]],

        [[ 1.0000,  0.0000],
         [ 0.2957,  1.0000],
         [-0.9668, -0.3335]]])
>>> A_U
tensor([[[ 2.1962,  1.0881],
         [ 0.0000, -0.8681]],

        [[-1.0947,  0.3736],
         [ 0.0000,  0.5718]]])
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
>>> torch.norm(A_ - A)
tensor(2.9802e-08)
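To make the unpacking step concrete, the following is a minimal pure-Python sketch of what splitting a packed LU factorization into P, L, and U involves for a single matrix. It is an illustration only, not PyTorch's implementation. The names lu_unpack_sketch and matmul are hypothetical helpers, and the sketch assumes LAPACK-style pivots: 1-based, where entry i names the row that was swapped with row i during elimination.

```python
def lu_unpack_sketch(lu, pivots):
    """Split a packed m x n LU matrix (list of rows) and 1-based pivots into P, L, U.

    Hypothetical helper for illustration; assumes LAPACK-style sequential row swaps.
    """
    m, n = len(lu), len(lu[0])
    k = min(m, n)
    # L is m x k: unit diagonal, strictly-lower entries taken from the packed data.
    L = [[1.0 if i == j else (lu[i][j] if i > j else 0.0) for j in range(k)]
         for i in range(m)]
    # U is k x n: diagonal and upper entries taken from the packed data.
    U = [[lu[i][j] if i <= j else 0.0 for j in range(n)] for i in range(k)]
    # Replay the row swaps to find which original row ended up in each position.
    perm = list(range(m))
    for i, p in enumerate(pivots):
        perm[i], perm[p - 1] = perm[p - 1], perm[i]
    # Build P so that A = P @ L @ U: row perm[i] of A equals row i of L @ U.
    P = [[0.0] * m for _ in range(m)]
    for col, row in enumerate(perm):
        P[row][col] = 1.0
    return P, L, U


def matmul(a, b):
    """Plain nested-list matrix product (hypothetical helper)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Packed factorization of [[2, 1], [4, 4]] with partial pivoting:
# the rows are swapped first, then the multiplier 0.5 is stored below the diagonal.
packed = [[4.0, 4.0], [0.5, -1.0]]
pivots = [2, 2]
P, L, U = lu_unpack_sketch(packed, pivots)
A = matmul(P, matmul(L, U))
print(A)
```

For a rectangular m x n input the sketch returns an m x m P, an m x min(m, n) L, and a min(m, n) x n U, which matches the shapes shown in the rectangular example above.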
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.lu_unpack.html