torch.lu_solve
torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor
Returns the LU solve of the linear system Ax = b using the partially pivoted LU factorization of A from torch.lu().

This function supports float, double, cfloat and cdouble dtypes for input.

- Parameters
  - b (Tensor) – the RHS tensor of size (*, m, k), where * is zero or more batch dimensions.
  - LU_data (Tensor) – the pivoted LU factorization of A from torch.lu(), of size (*, m, m), where * is zero or more batch dimensions.
  - LU_pivots (IntTensor) – the pivots of the LU factorization from torch.lu(), of size (*, m), where * is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.
- Keyword Arguments
  - out (Tensor, optional) – the output tensor.
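The shape relationships among the arguments can be checked with a short sketch. It assumes the API as documented on this page (later PyTorch releases may mark torch.lu and torch.lu_solve as deprecated); the batch size 4, m = 3 and k = 2 are arbitrary illustrative values, and the matrices are made strictly diagonally dominant only so that the systems are well conditioned.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions): batch of 4 systems, m = 3, k = 2.
# Adding 3 * I to entries drawn from [0, 1) makes every matrix strictly
# diagonally dominant, hence invertible and well conditioned.
A = torch.rand(4, 3, 3) + 3 * torch.eye(3)
b = torch.randn(4, 3, 2)

# torch.lu returns the packed factorization and the pivot indices.
LU_data, LU_pivots = torch.lu(A)
print(LU_data.shape)    # torch.Size([4, 3, 3])
print(LU_pivots.shape)  # torch.Size([4, 3])
print(LU_pivots.dtype)  # torch.int32

# The solution has the same shape as the right-hand side b.
x = torch.lu_solve(b, LU_data, LU_pivots)
print(x.shape)          # torch.Size([4, 3, 2])
```

The batch dimensions of LU_data and LU_pivots match because both come from the same torch.lu call.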
Example:
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> A_LU = torch.lu(A)
>>> x = torch.lu_solve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x) - b)
tensor(1.00000e-07 * 2.8312)
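A common reason to keep the factorization separate from the solve is to reuse it across several right-hand sides. The following is a minimal sketch of that pattern, not a prescribed usage: the diagonally dominant test matrices, the float64 dtype and the loop count of 3 are assumptions chosen so the residuals stay small and reproducible.

```python
import torch

torch.manual_seed(0)

# Assumption for illustration: strictly diagonally dominant matrices in
# float64, so each system is well conditioned and residuals are tiny.
A = torch.rand(2, 3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)

# Factor once ...
LU_data, LU_pivots = torch.lu(A)

# ... then solve against several right-hand sides without refactoring A.
residuals = []
for _ in range(3):
    b = torch.randn(2, 3, 1, dtype=torch.float64)
    x = torch.lu_solve(b, LU_data, LU_pivots)
    # Same residual check as in the example above: ||A x - b||
    residuals.append(torch.norm(torch.bmm(A, x) - b).item())

print(residuals)
```

Each residual should be close to floating-point round-off, which indicates that every solve reused the same factorization correctly.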
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.lu_solve.html