diff --git a/changelog.md b/changelog.md index c51d4e9c..360e5af8 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ Features * Let the `--dsn` argument accept literal DSNs as well as aliases. * Accept `--character-set` as an alias for `--charset` at the CLI. * Add SSL/TLS version to `status` output. +* More liberally accept `on`/`off` values for `true`/`false`, and vice versa. Bug Fixes diff --git a/mycli/config.py b/mycli/config.py index a79b1021..17fe4541 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -305,6 +305,13 @@ def str_to_bool(s: str | bool) -> bool: raise ValueError(f'not a recognized boolean value: {s}') +def str_to_on_off(s: str | bool) -> str: + bool_str = str(str_to_bool(s)) + if bool_str == 'True': + return 'on' + return 'off' + + def strip_matching_quotes(s: str) -> str: """Remove matching, surrounding quotes from a string. diff --git a/mycli/main.py b/mycli/main.py index a3899fe0..f647462a 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -56,7 +56,15 @@ from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher -from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes, write_default_config +from mycli.config import ( + get_mylogin_cnf_path, + open_mylogin_cnf, + read_config_files, + str_to_bool, + str_to_on_off, + strip_matching_quotes, + write_default_config, +) from mycli.key_bindings import mycli_bindings from mycli.lexer import MyCliLexer from mycli.packages import special @@ -220,7 +228,9 @@ def __init__( # set ssl_mode if a valid option is provided in a config file, otherwise None ssl_mode = c["main"].get("ssl_mode", None) or c["connection"].get("default_ssl_mode", None) - if ssl_mode not in ("auto", "on", "off", None): + if ssl_mode is None: + self.ssl_mode = ssl_mode + elif ssl_mode.lower() not in ("auto", "on", "off", "1", "0", "true", "false"): self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red") self.ssl_mode = None else: @@ -1659,7 +1669,7 @@ def get_last_query(self) -> str | None: "--ssl-mode", "ssl_mode", help="Set desired SSL behavior. auto=preferred, on=required, off=off.", - type=click.Choice(["auto", "on", "off"]), + type=str, ) @click.option("--ssl/--no-ssl", "ssl_enable", default=None, help="Enable SSL for connection (automatically enabled with other flags).") @click.option("--ssl-ca", help="CA file in PEM format.", type=click.Path(exists=True)) @@ -1995,6 +2005,14 @@ def get_password_from_file(password_file: str | None) -> str | None: ssl_enable = True ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option + if ssl_mode: + ssl_mode = ssl_mode.lower() + if ssl_mode and ssl_mode != 'auto': + try: + ssl_mode = str_to_on_off(ssl_mode) + except ValueError: + click.secho('Unknown value for ssl_mode', err=True, fg='red') + sys.exit(1) # if there is a mismatch between the ssl_mode value and other sources of ssl config, show a warning # specifically using "is False" to not pickup the case where ssl_enable is None (not set by the user) diff --git a/test/test_config.py b/test/test_config.py index 5bb0ab4f..fda9bc4e 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -16,6 +16,7 @@ read_and_decrypt_mylogin_cnf, read_config_file, str_to_bool, + str_to_on_off, strip_matching_quotes, ) @@ -149,6 +150,25 @@ def test_str_to_bool(): str_to_bool(None) +def test_str_to_on_off(): + assert str_to_on_off(False) == 'off' + assert str_to_on_off(True) == 'on' + assert str_to_on_off("False") == 'off' + assert str_to_on_off("True") == 'on' + assert str_to_on_off("TRUE") == 'on' + assert str_to_on_off("1") == 'on' + assert str_to_on_off("0") == 'off' + assert str_to_on_off("on") == 'on' + assert str_to_on_off("off") == 'off' + assert str_to_on_off("off") == 'off' + + with pytest.raises(ValueError): + str_to_on_off("foo") + + with pytest.raises(TypeError): + str_to_on_off(None) + + def test_read_config_file_list_values_default(): """Test that reading a config file uses list_values by default.""" diff --git a/test/test_main.py b/test/test_main.py index 1415f598..1fbd3e51 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -154,6 +154,17 @@ def test_ssl_mode_on(executor, capsys): assert ssl_cipher +@dbtest +def test_ssl_mode_true(executor, capsys): + runner = CliRunner() + ssl_mode = 'true' + sql = 'select * from performance_schema.session_status where variable_name = "Ssl_cipher"' + result = runner.invoke(cli, args=CLI_ARGS + ['--csv', '--ssl-mode', ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split('\n'))) + ssl_cipher = result_dict.get('VARIABLE_VALUE', None) + assert ssl_cipher + + @dbtest def test_ssl_mode_auto(executor, capsys): runner = CliRunner() @@ -176,6 +187,17 @@ def test_ssl_mode_off(executor, capsys): assert not ssl_cipher +@dbtest +def test_ssl_mode_false(executor, capsys): + runner = CliRunner() + ssl_mode = 'False' + sql = 'select * from performance_schema.session_status where variable_name = "Ssl_cipher"' + result = runner.invoke(cli, args=CLI_ARGS + ['--csv', '--ssl-mode', ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split('\n'))) + ssl_cipher = result_dict.get('VARIABLE_VALUE', None) + assert not ssl_cipher + + @dbtest def test_ssl_mode_overrides_ssl(executor, capsys): runner = CliRunner()