Skip to content

Commit

Permalink
fix example code
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick-Star125 committed Jun 6, 2023
1 parent 7f5219c commit e3ca1d5
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/paddle/sparse/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,12 +963,9 @@ def pca_lowrank(x, q=None, center=True, niter=2, name=None):
else:
sparse_x = dense_x.to_sparse_csr()
cuda_version = paddle.version.cuda()
if cuda_version is None or cuda_version == 'False' or int(cuda_version.split('.')[0]) < 11:
print("sparse.pca_lowrank API only support CUDA 11.x")
U, S, V = None, None, None
else:
U, S, V = paddle.sparse.pca_lowrank(sparse_x)
print("sparse.pca_lowrank API only support CUDA 11.x")
U, S, V = None, None, None
# U, S, V = pca_lowrank(sparse_x)
print(U)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
Expand Down Expand Up @@ -1087,7 +1084,11 @@ def svd_lowrank(x, q=6, niter=2, M=None):
raise ValueError('Input must be sparse, but got dense')

cuda_version = paddle.version.cuda()
if cuda_version is None or cuda_version == 'False' or int(cuda_version.split('.')[0]) < 11:
if (
cuda_version is None
or cuda_version == 'False'
or int(cuda_version.split('.')[0]) < 11
):
raise ValueError('sparse.pca_lowrank API only support CUDA 11.x')

(m, n) = x.shape[-2:]
Expand Down

0 comments on commit e3ca1d5

Please sign in to comment.