diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 8ce8ec64..2c67d511 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -13,6 +13,10 @@ class Requirement: UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?') UNTIL_SEP = re.compile(rb'[^;\s]+') + _SPECIAL_ORDER = { + b'--index-url': 0, + b'--extra-index-url': 1, + } def __init__(self) -> None: self.value: bytes | None = None @@ -30,6 +34,10 @@ def name(self) -> bytes: assert m is not None name = m.group() + if name == b'-i' or name.startswith(b'--index-url='): + return b'--index-url' + elif name.startswith(b'--extra-index-url='): + return b'--extra-index-url' m = self.UNTIL_COMPARISON.search(name) if not m: return name @@ -50,7 +58,13 @@ def __lt__(self, requirement: Requirement) -> bool: # with comments is kept) if self.name == requirement.name: return bool(self.comments) > bool(requirement.comments) - return self.name < requirement.name + return ( + self._SPECIAL_ORDER.get(self.name, 2), + self.name, + ) < ( + self._SPECIAL_ORDER.get(requirement.name, 2), + requirement.name, + ) def is_complete(self) -> bool: return ( diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index c0d2c65d..6991bcd8 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -107,6 +107,33 @@ PASS, b'a=2.0.0 \\\n --hash=sha256:abcd\nb==1.0.0\n', ), + ( + b'--extra-index-url https://example.com/simple\n' + b'--index-url https://pypi.org/simple\n' + b'requests\n', + FAIL, + b'--index-url https://pypi.org/simple\n' + b'--extra-index-url https://example.com/simple\n' + b'requests\n', + ), + ( + b'--extra-index-url https://example.com/simple\n' + b'-i https://pypi.org/simple\n' + b'requests\n', + FAIL, + b'-i https://pypi.org/simple\n' + b'--extra-index-url https://example.com/simple\n' + b'requests\n', + ), + ( + b'--extra-index-url=https://example.com/simple\n' + b'--index-url=https://pypi.org/simple\n' + b'requests\n', + FAIL, + b'--index-url=https://pypi.org/simple\n' + b'--extra-index-url=https://example.com/simple\n' + b'requests\n', + ), ), ) def test_integration(input_s, expected_retval, output, tmpdir):