Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 93 additions & 2 deletions src/git/src/mcp_server_git/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from pathlib import Path
from typing import Sequence, Optional
from mcp.server import Server
Expand Down Expand Up @@ -92,6 +93,19 @@ class GitBranch(BaseModel):
)


class GitCurrentBranch(BaseModel):
repo_path: str

class GitDefaultBranch(BaseModel):
repo_path: str
remote: str = Field(
"origin",
description="The remote to get the default branch for (defaults to 'origin')"
)

class GitRemote(BaseModel):
repo_path: str

class GitTools(str, Enum):
STATUS = "git_status"
DIFF_UNSTAGED = "git_diff_unstaged"
Expand All @@ -106,6 +120,9 @@ class GitTools(str, Enum):
SHOW = "git_show"

BRANCH = "git_branch"
CURRENT_BRANCH = "git_current_branch"
DEFAULT_BRANCH = "git_default_branch"
REMOTE = "git_remote"

def git_status(repo: git.Repo) -> str:
return repo.git.status()
Expand Down Expand Up @@ -268,6 +285,42 @@ def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, no

return branch_info

def git_current_branch(repo: git.Repo) -> str:
if repo.head.is_detached:
return f"HEAD detached at {repo.head.commit.hexsha[:7]}"
return repo.active_branch.name

def git_default_branch(repo: git.Repo, remote: str = "origin") -> str:
# Try git ls-remote --symref to detect remote HEAD
try:
output = repo.git.ls_remote("--symref", remote, "HEAD")
# Output format: "ref: refs/heads/main\tHEAD\n<sha>\tHEAD"
match = re.search(r"^ref: refs/heads/(\S+)\t", output, re.MULTILINE)
if match:
return f"{remote}/{match.group(1)}"
except git.GitCommandError:
pass

# Try local ref resolution via rev-parse (returns "origin/main" directly)
try:
return repo.git.rev_parse("--abbrev-ref", f"{remote}/HEAD")
except git.GitCommandError:
pass

# Fallback: check for common local branch names
local_branches = [ref.name for ref in repo.branches]
if "main" in local_branches:
return f"{remote}/main"
if "master" in local_branches:
return f"{remote}/master"

raise ValueError(
f"Could not determine the default branch for remote '{remote}'"
)

def git_remote(repo: git.Repo) -> str:
return repo.git.remote("-v")


async def serve(repository: Path | None) -> None:
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -345,8 +398,22 @@ async def list_tools() -> list[Tool]:
name=GitTools.BRANCH,
description="List Git branches",
inputSchema=GitBranch.model_json_schema(),

)
),
Tool(
name=GitTools.CURRENT_BRANCH,
description="Returns the name of the currently checked out branch, or the commit SHA if HEAD is detached",
inputSchema=GitCurrentBranch.model_json_schema(),
),
Tool(
name=GitTools.DEFAULT_BRANCH,
description="Returns the default branch for a remote in '<remote>/<branch>' format (e.g., 'origin/main'). Queries the remote directly, with fallback to local ref resolution and branch detection.",
inputSchema=GitDefaultBranch.model_json_schema(),
),
Tool(
name=GitTools.REMOTE,
description="Lists all configured remotes with their fetch and push URLs",
inputSchema=GitRemote.model_json_schema(),
),
]

async def list_repos() -> Sequence[str]:
Expand Down Expand Up @@ -488,6 +555,30 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
text=result
)]

case GitTools.CURRENT_BRANCH:
result = git_current_branch(repo)
return [TextContent(
type="text",
text=result
)]

case GitTools.DEFAULT_BRANCH:
result = git_default_branch(
repo,
arguments.get("remote", "origin")
)
return [TextContent(
type="text",
text=result
)]

case GitTools.REMOTE:
result = git_remote(repo)
return [TextContent(
type="text",
text=result
)]

case _:
raise ValueError(f"Unknown tool: {name}")

Expand Down
145 changes: 145 additions & 0 deletions src/git/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from mcp_server_git.server import (
git_checkout,
git_branch,
git_current_branch,
git_default_branch,
git_remote,
git_add,
git_status,
git_diff_unstaged,
Expand Down Expand Up @@ -423,3 +426,145 @@ def test_git_checkout_rejects_malicious_refs(test_repository):

# Cleanup
malicious_ref_path.unlink()


# Tests for git_current_branch

def test_git_current_branch(test_repository):
result = git_current_branch(test_repository)
assert result == test_repository.active_branch.name

def test_git_current_branch_detached_head(test_repository):
commit_sha = test_repository.head.commit.hexsha
test_repository.git.checkout(commit_sha)
result = git_current_branch(test_repository)
assert "detached" in result.lower()
assert commit_sha[:7] in result


# Tests for git_default_branch

def test_git_default_branch_fallback_local(test_repository):
"""Repo with no remote; falls back to detecting the local default branch name."""
default_branch = test_repository.active_branch.name
result = git_default_branch(test_repository)
assert result == f"origin/{default_branch}"

def test_git_default_branch_with_remote(tmp_path):
"""Create a bare remote repo, add it as origin, verify ls-remote detection works."""
# Create a bare repo to act as the remote
bare_path = tmp_path / "bare_remote.git"
bare_repo = git.Repo.init(bare_path, bare=True)

# Create a local repo and push to the bare remote
local_path = tmp_path / "local_repo"
local_repo = git.Repo.init(local_path)

Path(local_path / "test.txt").write_text("test")
local_repo.index.add(["test.txt"])
local_repo.index.commit("initial commit")

local_repo.create_remote("origin", str(bare_path))
local_repo.git.push("--set-upstream", "origin", local_repo.active_branch.name)

result = git_default_branch(local_repo)
assert result == f"origin/{local_repo.active_branch.name}"

shutil.rmtree(local_path)
shutil.rmtree(bare_path)

def test_git_default_branch_custom_remote(tmp_path):
"""Add a remote with a non-'origin' name, verify the remote parameter selects it."""
bare_path = tmp_path / "custom_remote.git"
bare_repo = git.Repo.init(bare_path, bare=True)

local_path = tmp_path / "local_repo"
local_repo = git.Repo.init(local_path)

Path(local_path / "test.txt").write_text("test")
local_repo.index.add(["test.txt"])
local_repo.index.commit("initial commit")

local_repo.create_remote("upstream", str(bare_path))
local_repo.git.push("--set-upstream", "upstream", local_repo.active_branch.name)

result = git_default_branch(local_repo, remote="upstream")
assert result == f"upstream/{local_repo.active_branch.name}"

shutil.rmtree(local_path)
shutil.rmtree(bare_path)

def test_git_default_branch_undetectable(tmp_path):
"""Repo with no remotes and no main/master branch; should raise ValueError."""
repo_path = tmp_path / "no_default_repo"
repo = git.Repo.init(repo_path)

# Create a commit on a non-standard branch name
repo.git.checkout("-b", "develop")
Path(repo_path / "test.txt").write_text("test")
repo.index.add(["test.txt"])
repo.index.commit("initial commit")

with pytest.raises(ValueError, match="Could not determine the default branch"):
git_default_branch(repo)

shutil.rmtree(repo_path)

def test_git_default_branch_revparse_fallback(tmp_path):
"""When ls-remote fails but local ref cache exists, rev-parse fallback should work."""
# Create a bare repo to act as the remote
bare_path = tmp_path / "bare_remote.git"
git.Repo.init(bare_path, bare=True)

# Create a local repo and push to the bare remote
local_path = tmp_path / "local_repo"
local_repo = git.Repo.init(local_path)

Path(local_path / "test.txt").write_text("test")
local_repo.index.add(["test.txt"])
local_repo.index.commit("initial commit")

active_branch = local_repo.active_branch.name
local_repo.create_remote("origin", str(bare_path))
local_repo.git.push("--set-upstream", "origin", active_branch)

# Populate local ref cache for origin/HEAD
local_repo.git.remote("set-head", "origin", "--auto")

# Replace remote URL with an invalid path so ls-remote will fail
local_repo.git.remote("set-url", "origin", "/nonexistent/path")

result = git_default_branch(local_repo)
assert result == f"origin/{active_branch}"

shutil.rmtree(local_path)
shutil.rmtree(bare_path)


# Tests for git_remote

def test_git_remote_no_remotes(test_repository):
"""Repo with no remotes; verify empty output."""
result = git_remote(test_repository)
assert result == ""

def test_git_remote_with_remote(tmp_path):
"""Repo with a remote configured; verify remote name and URL appear in output."""
bare_path = tmp_path / "bare_remote.git"
git.Repo.init(bare_path, bare=True)

local_path = tmp_path / "local_repo"
local_repo = git.Repo.init(local_path)

Path(local_path / "test.txt").write_text("test")
local_repo.index.add(["test.txt"])
local_repo.index.commit("initial commit")

local_repo.create_remote("origin", str(bare_path))

result = git_remote(local_repo)
assert "origin" in result
assert str(bare_path) in result

shutil.rmtree(local_path)
shutil.rmtree(bare_path)