Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorcircuit/cloud/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def set_token(
if cached:
file_token = {k: b64encode_s(v) for k, v in saved_token.items()}
if file_token:
with open(authpath, "w") as f:
fd = os.open(authpath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w") as f:
os.chmod(authpath, 0o600)
json.dump(file_token, f)

return saved_token
Expand Down
71 changes: 68 additions & 3 deletions tests/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import time
import pytest
import numpy as np
import stat
import json
from unittest.mock import patch

thisfile = os.path.abspath(__file__)
modulepath = os.path.dirname(os.path.dirname(thisfile))
Expand All @@ -12,25 +15,26 @@
from tensorcircuit.cloud import apis, wrapper
from tensorcircuit.results import counts

if "TC_CLOUD_TEST" not in os.environ:
pytest.skip(allow_module_level=True)
# skip on CI due to no token
skip_cloud = pytest.mark.skipif("TC_CLOUD_TEST" not in os.environ, reason="no token")
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you better to organize the test in this file into two classes, one with skip_cloud, the other one without?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I have reorganized tests/test_cloud.py into two classes: TestCloud (decorated with skip_cloud) and TestCloudAuth (not skipped). This allows the authentication unit test to run independently of the integration tests.



@skip_cloud
def test_get_token():
print(apis.get_token(provider="Tencent"))
p = apis.get_provider("tencent")
print(p.get_token())
print(p.get_device("simulator:tc").get_token())


@skip_cloud
def test_list_devices():
print(apis.list_devices())
p = apis.get_provider()
print(p.list_devices())
print(p.list_devices(state="on"))


@skip_cloud
def test_get_device():
d1 = apis.get_device(device="tencent::hello")
assert d1.name == "hello"
Expand All @@ -47,6 +51,7 @@ def test_get_device():
assert d4.provider.name == "tencent"


@skip_cloud
def test_get_device_cache():
d1 = apis.get_device("local::testing")
d2 = apis.get_device(provider="local", device="testing")
Expand All @@ -60,6 +65,7 @@ def test_get_device_cache():
assert id(d4) != id(d1)


@skip_cloud
def test_list_properties():
d = apis.get_device(device="simulator:aer")
print(d.list_properties())
Expand All @@ -68,6 +74,7 @@ def test_list_properties():
apis.list_properties(device="hell")


@skip_cloud
def test_submit_task():
c = tc.Circuit(3)
c.H(0)
Expand All @@ -80,6 +87,7 @@ def test_submit_task():
assert t.get_logical_physical_mapping() == {0: 0, 1: 1, 2: 2}


@skip_cloud
def test_resubmit_task():
c = tc.Circuit(3)
c.H(0)
Expand All @@ -91,6 +99,7 @@ def test_resubmit_task():
print(t1.details(wait=True))


@skip_cloud
def test_get_task():
apis.set_device("simulator:tcn1")
c = tc.Circuit(2)
Expand All @@ -104,17 +113,20 @@ def test_get_task():
apis.set_device()


@skip_cloud
def test_list_tasks():
d = apis.get_device(device="simulator:aer")
print(d.list_tasks())
print(apis.list_tasks(device="simulator:tc"))


@skip_cloud
def test_local_list_device():
dd = apis.list_devices(provider="local")
assert dd[0].name == "testing"


@skip_cloud
def test_local_submit_task():
c = tc.Circuit(2)
c.h(0)
Expand All @@ -127,10 +139,12 @@ def test_local_submit_task():
print(t.get_device())


@skip_cloud
def test_local_list_tasks():
print(apis.list_tasks(provider="local"))


@skip_cloud
def test_local_batch_submit():
apis.set_provider("local")
c = tc.Circuit(2)
Expand All @@ -147,6 +161,7 @@ def test_local_batch_submit():
apis.set_provider("tencent")


@skip_cloud
def test_batch_exp_ps():
pss = [[1, 0], [0, 3]]
c = tc.Circuit(2)
Expand All @@ -168,6 +183,7 @@ def test_batch_exp_ps():
)


@skip_cloud
def test_batch_submit_template():
run = tc.cloud.wrapper.batch_submit_template(
device="simulator:tc", batch_limit=2, prior=10
Expand All @@ -184,6 +200,7 @@ def test_batch_submit_template():
assert len(rs) == 4


@skip_cloud
def test_allz_batch():
n = 5

Expand Down Expand Up @@ -217,3 +234,51 @@ def qmlf(inputs, params, device=None):
params = np.ones([2, n])
print(qmlf(inputs, params, device="9gmon"))
print(qmlf(inputs, params))


def test_token_file_permissions(tmp_path):
# Mock os.path.expanduser to return tmp_path
# We patch where it is used in tensorcircuit.cloud.apis
with patch("tensorcircuit.cloud.apis.os.path.expanduser", return_value=str(tmp_path)):
authpath = tmp_path / ".tc.auth.json"

# Ensure clean state for saved_token global variable if necessary
# set_token with clear=True clears the global saved_token
apis.set_token(clear=True)

# Scenario 1: File creation (new file)
# Set a dummy token
apis.set_token(token="dummy_token_1", provider="tencent", cached=True)

# Verify file exists
assert authpath.exists()

# Verify permissions (only on POSIX systems where these bits are meaningful)
if os.name == "posix":
st = os.stat(authpath)
# Check that group and others have no permissions (should be 0)
assert (st.st_mode & (stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH)) == 0

# Scenario 2: File update (existing file with insecure permissions)
if os.name == "posix":
# Manually set insecure permissions to simulate an existing insecure file
os.chmod(authpath, 0o666)
st_before = os.stat(authpath)
# Verify it is indeed insecure (readable by others)
assert (st_before.st_mode & stat.S_IROTH)

# Update token (add another token or update existing)
# This triggers the write logic again
apis.set_token(token="dummy_token_2", provider="local", cached=True)

# Verify permissions again (should be fixed to 0600)
if os.name == "posix":
st_after = os.stat(authpath)
assert (st_after.st_mode & (stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH)) == 0

# Verify content updated and readable
with open(authpath, "r") as f:
data = json.load(f)
assert data is not None
Loading