Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Release history

## trio-websocket x.y.z
### Fixed
- fix the client hanging upon a certificate issue
([#199](https://github.com/python-trio/trio-websocket/issues/199))

## trio-websocket 0.12.2 (2025-02-24)
### Fixed
- fix incorrect port when using a `wss://` URL without supplying an explicit
Expand Down
103 changes: 103 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from __future__ import annotations

import copy
import datetime
import re
import ssl
import sys
Expand Down Expand Up @@ -276,6 +277,108 @@ async def test_serve_ssl(nursery: trio.Nursery) -> None:
assert conn.remote.is_ssl


async def test_serve_ssl_wrong_ca(nursery: trio.Nursery) -> None:
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
client_context = ssl.create_default_context()
ca = trustme.CA()
other_ca = trustme.CA()
other_ca.configure_trust(client_context)
cert = ca.issue_server_cert(HOST)
cert.configure_cert(server_context)

server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0,
server_context)
assert isinstance(server, WebSocketServer)
port = server.port
with trio.fail_after(0.1):
with pytest.raises(HandshakeError) as excinfo:
async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context
) as conn:
assert not conn.closed
assert isinstance(conn.local, Endpoint)
assert conn.local.is_ssl
assert isinstance(conn.remote, Endpoint)
assert conn.remote.is_ssl
assert isinstance(excinfo.value.__cause__, ssl.SSLError)
assert excinfo.value.__cause__.reason == "CERTIFICATE_VERIFY_FAILED"


async def test_ssl_client_cert(nursery: trio.Nursery) -> None:

ca = trustme.CA()

server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
cert = ca.issue_server_cert(HOST)
cert.configure_cert(server_context)
server_context.verify_mode = ssl.CERT_REQUIRED
ca.configure_trust(server_context)

# Setup a valid client certificate.
good_client_context = ssl.create_default_context()
ca.configure_trust(good_client_context)
good_client_cert = ca.issue_cert("user@example.org")
good_client_cert.configure_cert(good_client_context)

# Setup an expired client certificate.
bad_client_context = ssl.create_default_context()
ca.configure_trust(bad_client_context)
bad_client_cert = ca.issue_cert(
"user@example.org", not_after=datetime.datetime.now(datetime.UTC))
bad_client_cert.configure_cert(bad_client_context)

# Use the timeout because the old SSL code made the client hang.
with trio.fail_after(0.5):
async with trio.open_nursery() as nurs:

server = await nurs.start(
serve_websocket, echo_request_handler, HOST, 0, server_context)
assert isinstance(server, WebSocketServer)
port = server.port

# Test with the valid certificate.
async with open_websocket(
HOST, port, RESOURCE, use_ssl=good_client_context
) as conn:
assert not conn.closed
assert isinstance(conn.local, Endpoint)
assert conn.local.is_ssl
assert isinstance(conn.remote, Endpoint)
assert conn.remote.is_ssl
await conn.send_message('foo')
assert await conn.get_message() == 'foo'

# Test with the expired certificate.
with pytest.raises(HandshakeError) as excinfo:
async with open_websocket(
HOST, port, RESOURCE, use_ssl=bad_client_context
) as conn:
assert not conn.closed
assert isinstance(conn.local, Endpoint)
assert conn.local.is_ssl
assert isinstance(conn.remote, Endpoint)
assert conn.remote.is_ssl
assert isinstance(excinfo.value.__cause__, ssl.SSLError)
assert excinfo.value.__cause__.reason == "SSLV3_ALERT_CERTIFICATE_EXPIRED"

# Test with the valid certificate again. If this does work now,
# this means that the expired certificate crashed the server.
try:
async with open_websocket(
HOST, port, RESOURCE, use_ssl=good_client_context
) as conn:
assert not conn.closed
assert isinstance(conn.local, Endpoint)
assert conn.local.is_ssl
assert isinstance(conn.remote, Endpoint)
assert conn.remote.is_ssl
await conn.send_message('foo')
assert await conn.get_message() == 'foo'
except:
raise RuntimeError("The server crashed in the first subtest") from None

nurs.cancel_scope.cancel()


async def test_serve_handler_nursery(nursery: trio.Nursery) -> None:
async with trio.open_nursery() as handler_nursery:
serve_with_nursery = partial(serve_websocket, echo_request_handler,
Expand Down
25 changes: 22 additions & 3 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,8 +1570,11 @@ async def _reader_task(self) -> None:
# Get network data.
try:
data = await self._stream.receive_some(self._receive_buffer_size)
except (trio.BrokenResourceError, trio.ClosedResourceError):
except (trio.BrokenResourceError, trio.ClosedResourceError) as exc:
await self._abort_web_socket()
# Wrap SSL errors into a HandshakeError.
if isinstance(exc.__cause__, ssl.SSLError):
raise HandshakeError() from exc.__cause__
break
if len(data) == 0:
logger.debug('%s received zero bytes (connection closed)',
Expand Down Expand Up @@ -1608,8 +1611,11 @@ async def _send(self, event: wsproto.events.Event) -> None:
logger.debug('%s sending %d bytes', self, len(data))
try:
await self._stream.send_all(data)
except (trio.BrokenResourceError, trio.ClosedResourceError):
except (trio.BrokenResourceError, trio.ClosedResourceError) as exc:
await self._abort_web_socket()
# Wrap SSL errors into a HandshakeError.
if isinstance(exc.__cause__, ssl.SSLError):
raise HandshakeError() from exc.__cause__
assert self._close_reason is not None
raise ConnectionClosed(self._close_reason) from None

Expand Down Expand Up @@ -1783,13 +1789,26 @@ async def _handle_connection(self, stream: trio.abc.Stream) -> None:
:param stream:
:type stream: trio.abc.Stream
'''

# Filter out "HandshakeError"s caused by "SSLError"s as we don't want a
# connection error to crash the server.
async def _reader_task():
try:
await connection._reader_task()
except* HandshakeError as excs:
non_ssl_errs = excs.subgroup(
lambda e: not isinstance(e, ExceptionGroup)
and not isinstance(e.__cause__, ssl.SSLError))
if non_ssl_errs:
raise non_ssl_errs

async with trio.open_nursery() as nursery:
connection = WebSocketConnection(stream,
WSConnection(ConnectionType.SERVER),
message_queue_size=self._message_queue_size,
max_message_size=self._max_message_size,
receive_buffer_size=self._receive_buffer_size)
nursery.start_soon(connection._reader_task)
nursery.start_soon(_reader_task)
with trio.move_on_after(self._connect_timeout) as connect_scope:
request = await connection._get_request()
if connect_scope.cancelled_caught:
Expand Down
Loading