Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/workos/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Dict, Optional, Type, cast, overload
from urllib.parse import quote

import httpx

Expand Down Expand Up @@ -53,8 +54,15 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self._is_public = is_public
# Public clients (PKCE / browser / mobile / CLI) must never attach
# an API key, even if WORKOS_API_KEY is present in the environment.
if is_public:
self._api_key: Optional[str] = None
else:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self.client_id = client_id or os.environ.get("WORKOS_CLIENT_ID")
if not self._api_key and not self.client_id:
raise ValueError(
Expand Down Expand Up @@ -128,6 +136,21 @@ def _resolve_base_url(self, request_options: Optional[RequestOptions]) -> str:
return str(base_url).rstrip("/")
return self._base_url.rstrip("/")

@staticmethod
def _encode_path(path: str) -> str:
"""Percent-encode each path segment to prevent path-traversal/injection.

Splits on ``/`` and applies ``urllib.parse.quote(seg, safe='')`` to each
segment so that user-supplied IDs containing reserved characters (``/``,
``?``, ``#``, ``%``, etc.) cannot escape their intended segment. The
leading slash (if any) is preserved.
"""
if not path:
return path
leading = "/" if path.startswith("/") else ""
body = path[1:] if leading else path
return leading + "/".join(quote(seg, safe="") for seg in body.split("/"))
Comment on lines +139 to +152
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 _encode_path cannot protect against / in user-supplied IDs

_encode_path is called on the fully-constructed path string after f-string interpolation has already merged static slashes with any slashes present in user-supplied values. For a call like path=f"vault/v1/kv/{object_id}" where object_id = "id/../../../auth/admin", the path reaches _encode_path as the flat string "vault/v1/kv/id/../../../auth/admin". The split-on-/+encode loop sees each segment individually (including the .. segments, which quote(seg, safe="") preserves because . is an RFC 3986 unreserved character), and reconstructs the same traversal path unchanged.

The docstring's claim that user-supplied IDs "cannot escape their intended segment with /" is incorrect for this calling pattern — only ?, #, and % injection are actually blocked. The fix would need to encode user-supplied ID components before interpolating them into the path string (or accept path components as separate arguments).


def _resolve_timeout(self, request_options: Optional[RequestOptions]) -> float:
timeout = self._request_timeout
if request_options:
Expand Down Expand Up @@ -332,6 +355,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the WorkOS client.

Expand All @@ -342,6 +366,10 @@ def __init__(
request_timeout: HTTP request timeout in seconds. Falls back to WORKOS_REQUEST_TIMEOUT or 60.
jwt_leeway: JWT clock skew leeway in seconds.
max_retries: Maximum number of retries for failed requests. Defaults to 3.
is_public: When True, mark this client as public (PKCE / browser
/ mobile / CLI). The API key is forced to None and the
``WORKOS_API_KEY`` environment variable is ignored. Use
``create_public_client`` instead of setting this directly.

Raises:
ValueError: If neither api_key nor client_id is provided, directly or via environment variables.
Expand All @@ -353,6 +381,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.Client(
timeout=self._request_timeout, follow_redirects=True
Expand Down Expand Up @@ -406,7 +435,7 @@ def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down Expand Up @@ -557,6 +586,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the async WorkOS client.

Expand All @@ -578,6 +608,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.AsyncClient(
timeout=self._request_timeout, follow_redirects=True
Expand Down Expand Up @@ -631,7 +662,7 @@ async def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an async HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path).lstrip('/')}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down
2 changes: 1 addition & 1 deletion src/workos/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _verify_signature(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > tolerance:
if abs(seconds_since_issued) > tolerance:
raise ValueError("Timestamp outside the tolerance zone")

body_str = payload.decode("utf-8") if isinstance(payload, bytes) else payload
Expand Down
2 changes: 2 additions & 0 deletions src/workos/public_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def create_public_client(
from ._client import WorkOSClient

return WorkOSClient(
api_key=None,
client_id=client_id,
base_url=base_url,
request_timeout=request_timeout,
is_public=True,
)
45 changes: 43 additions & 2 deletions src/workos/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
from cryptography.fernet import Fernet
from jwt import PyJWKClient

from ._errors import (
AuthenticationError,
EmailVerificationRequiredError,
MfaChallengeError,
OrganizationSelectionRequiredError,
SsoRequiredError,
WorkOSConnectionError,
WorkOSTimeoutError,
)

if TYPE_CHECKING:
from ._client import AsyncWorkOSClient, WorkOSClient

Expand All @@ -37,6 +47,37 @@ class AuthenticateWithSessionCookieFailureReason(Enum):
INVALID_JWT = "invalid_jwt"
INVALID_SESSION_COOKIE = "invalid_session_cookie"
NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided"
MFA_CHALLENGE_REQUIRED = "mfa_challenge_required"
SSO_REQUIRED = "sso_required"
EMAIL_VERIFICATION_REQUIRED = "email_verification_required"
ORGANIZATION_SELECTION_REQUIRED = "organization_selection_required"
REFRESH_DENIED = "refresh_denied"
REFRESH_NETWORK_ERROR = "refresh_network_error"


def _map_refresh_exception_to_reason(
exc: Exception,
) -> Union[AuthenticateWithSessionCookieFailureReason, str]:
"""Map an exception raised by a refresh request to a structured reason.

Falls back to ``str(exc)`` for unknown errors so callers retain the
pre-existing string form for diagnostics.
"""
if isinstance(exc, MfaChallengeError):
return AuthenticateWithSessionCookieFailureReason.MFA_CHALLENGE_REQUIRED
if isinstance(exc, SsoRequiredError):
return AuthenticateWithSessionCookieFailureReason.SSO_REQUIRED
if isinstance(exc, EmailVerificationRequiredError):
return AuthenticateWithSessionCookieFailureReason.EMAIL_VERIFICATION_REQUIRED
if isinstance(exc, OrganizationSelectionRequiredError):
return (
AuthenticateWithSessionCookieFailureReason.ORGANIZATION_SELECTION_REQUIRED
)
if isinstance(exc, AuthenticationError):
return AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED
if isinstance(exc, (WorkOSConnectionError, WorkOSTimeoutError)):
return AuthenticateWithSessionCookieFailureReason.REFRESH_NETWORK_ERROR
return str(exc)


@dataclass(slots=True)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Expand Down Expand Up @@ -328,7 +369,7 @@ def refresh(
)
except Exception as e:
return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=str(e)
authenticated=False, reason=_map_refresh_exception_to_reason(e)
)

def get_logout_url(self, return_to: Optional[str] = None) -> str:
Expand Down Expand Up @@ -507,7 +548,7 @@ async def refresh(
)
except Exception as e:
return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=str(e)
authenticated=False, reason=_map_refresh_exception_to_reason(e)
)

async def get_logout_url(self, return_to: Optional[str] = None) -> str:
Expand Down
24 changes: 17 additions & 7 deletions src/workos/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _aes_gcm_encrypt(
encryptor = Cipher(
algorithms.AES(key), modes.GCM(iv), backend=default_backend()
).encryptor()
if aad:
if aad is not None:
encryptor.authenticate_additional_data(aad)
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return {"ciphertext": ciphertext, "iv": iv, "tag": encryptor.tag}
Expand All @@ -256,7 +256,7 @@ def _aes_gcm_decrypt(
decryptor = Cipher(
algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend()
).decryptor()
if aad:
if aad is not None:
decryptor.authenticate_additional_data(aad)
return decryptor.update(ciphertext) + decryptor.finalize()

Expand All @@ -282,10 +282,12 @@ def _decode_u32_leb128(buf: bytes) -> Tuple[int, int]:
res = 0
bit = 0
for i, b in enumerate(buf):
if i > 4:
if i >= 4 and (b & 0x80) != 0:
raise ValueError("LEB128 integer overflow (was more than 4 bytes)")
res |= (b & 0x7F) << (7 * bit)
if (b & 0x80) == 0:
if res > 0xFFFFFFFF:
raise ValueError("LEB128 integer overflow (exceeds 32 bits)")
return res, i + 1
bit += 1
raise ValueError("LEB128 integer not found")
Expand Down Expand Up @@ -468,7 +470,9 @@ def encrypt(
key = base64.b64decode(key_pair.data_key.key)
key_blob = base64.b64decode(key_pair.encrypted_keys)
prefix_len_buffer = _encode_u32_leb128(len(key_blob))
aad_buffer = associated_data.encode("utf-8") if associated_data else None
aad_buffer = (
associated_data.encode("utf-8") if associated_data is not None else None
)
iv = os.urandom(12)

result = _aes_gcm_encrypt(data.encode("utf-8"), key, iv, aad_buffer)
Expand All @@ -490,7 +494,9 @@ def decrypt(
data_key = self.decrypt_data_key(keys=decoded.keys)

key = base64.b64decode(data_key.key)
aad_buffer = associated_data.encode("utf-8") if associated_data else None
aad_buffer = (
associated_data.encode("utf-8") if associated_data is not None else None
)

decrypted_bytes = _aes_gcm_decrypt(
ciphertext=decoded.ciphertext,
Expand Down Expand Up @@ -647,7 +653,9 @@ async def encrypt(
key = base64.b64decode(key_pair.data_key.key)
key_blob = base64.b64decode(key_pair.encrypted_keys)
prefix_len_buffer = _encode_u32_leb128(len(key_blob))
aad_buffer = associated_data.encode("utf-8") if associated_data else None
aad_buffer = (
associated_data.encode("utf-8") if associated_data is not None else None
)
iv = os.urandom(12)

result = _aes_gcm_encrypt(data.encode("utf-8"), key, iv, aad_buffer)
Expand All @@ -668,7 +676,9 @@ async def decrypt(
data_key = await self.decrypt_data_key(keys=decoded.keys)

key = base64.b64decode(data_key.key)
aad_buffer = associated_data.encode("utf-8") if associated_data else None
aad_buffer = (
associated_data.encode("utf-8") if associated_data is not None else None
)

decrypted_bytes = _aes_gcm_decrypt(
ciphertext=decoded.ciphertext,
Expand Down
4 changes: 2 additions & 2 deletions src/workos/webhooks/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def verify_header(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

body_str = (
Expand Down Expand Up @@ -520,7 +520,7 @@ def verify_header(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

body_str = (
Expand Down
4 changes: 2 additions & 2 deletions src/workos/webhooks/_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def verify_header(

issued_timestamp = issued_timestamp[2:]
signature_hash = signature_hash[3:]
max_seconds_since_issued = tolerance or DEFAULT_TOLERANCE
max_seconds_since_issued = tolerance if tolerance is not None else DEFAULT_TOLERANCE
current_time = time.time()
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > max_seconds_since_issued:
if abs(seconds_since_issued) > max_seconds_since_issued:
raise ValueError("Timestamp outside the tolerance zone")

unhashed_string = "{0}.{1}".format(issued_timestamp, event_body.decode("utf-8"))
Expand Down
11 changes: 11 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def test_verify_header_stale_timestamp(self):
tolerance=30,
)

def test_verify_header_future_timestamp(self):
future_ts = int((time.time() + 60) * 1000)
sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, future_ts)
with pytest.raises(ValueError, match="tolerance zone"):
self.actions.verify_header(
payload=SAMPLE_ACTION_PAYLOAD,
sig_header=sig,
secret=SECRET,
tolerance=30,
)

def test_verify_header_custom_tolerance(self):
old_ts = int((time.time() - 10) * 1000)
sig = _make_sig_header(SAMPLE_ACTION_PAYLOAD, SECRET, old_ts)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_webhook_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def test_verify_header_stale_timestamp(self, workos):
tolerance=180,
)

def test_verify_header_future_timestamp(self, workos):
future_ts = int((time.time() + 300) * 1000)
sig = _make_sig_header(SAMPLE_EVENT, SECRET, future_ts)
with pytest.raises(ValueError, match="tolerance zone"):
workos.webhooks.verify_header(
event_body=SAMPLE_EVENT,
event_signature=sig,
secret=SECRET,
tolerance=180,
)


class TestStandaloneVerifyEvent:
def test_standalone_verify_event(self):
Expand Down Expand Up @@ -157,3 +168,14 @@ def test_standalone_verify_header_invalid(self):
event_signature=sig,
secret=SECRET,
)

def test_standalone_verify_header_future_timestamp(self):
future_ts = int((time.time() + 300) * 1000)
sig = _make_sig_header(SAMPLE_EVENT, SECRET, future_ts)
with pytest.raises(ValueError, match="tolerance zone"):
standalone_verify_header(
event_body=SAMPLE_EVENT.encode("utf-8"),
event_signature=sig,
secret=SECRET,
tolerance=180,
)
Loading