22 changes: 16 additions & 6 deletions agents/coordinator_guide.md
@@ -121,7 +121,7 @@ project before making changes so you can verify your setup works.
- Information about existing PRs — what they change, whether they look correct
- Anything else the worker agent should know

**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment.
**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment. Include which specific test file(s) or test function(s) the agent should run to verify its fix — not the full suite.

**6. Completion workflow.** Every prompt file must include this section verbatim, with the issue number filled in:

@@ -130,20 +130,29 @@

After implementing and verifying the fix:

1. **Commit** your changes with a message referencing the issue:
1. **Run only the tests relevant to your change.** Do NOT run the full
test suite — it takes 10+ minutes and will be run separately later.
Instead, run the specific test file(s) that cover the code you changed:

pytest tests/test_autograd.py -v --tb=short -k "relevant_test_name"

If you wrote a new test, run that plus the existing tests in the same
file to check for regressions in that area.

2. **Commit** your changes with a message referencing the issue:

git add <files>
git commit -m "Fix <brief description> (#<NUMBER>)"

2. **Push** the branch:
3. **Push** the branch:

git push -u origin fix/issue-<NUMBER>

3. **Create a pull request** with `gh pr create`. The PR body must
4. **Create a pull request** with `gh pr create`. The PR body must
include "Fixes #<NUMBER>" so GitHub auto-links and auto-closes the
issue on merge. Describe what the fix does and how you verified it.

4. **Post to the bitsandbytes Slack channel** to notify the team.
5. **Post to the bitsandbytes Slack channel** to notify the team.
Write a temporary Python script to `/tmp/slack_notify.py` and run it:

import json, urllib.request, sys
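
    The rest of the script is collapsed in this diff. Purely as an
    illustration, and not the guide's actual script, a minimal
    webhook-based notifier using only those imports might look like
    this (the webhook URL and message passed via argv are assumptions):

    # Hypothetical sketch only; the guide's real script is collapsed in this diff.
    # Usage (illustrative): python /tmp/slack_notify.py <webhook_url> "<message>"
    import json, urllib.request, sys

    webhook_url, message = sys.argv[1], sys.argv[2]
    payload = json.dumps({"text": message}).encode("utf-8")
    req = urllib.request.Request(
        webhook_url, data=payload, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        print(f"Slack responded with HTTP {resp.status}")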
@@ -245,7 +254,8 @@ whether it is correct and complete before implementing from scratch.

## When You Are Done

[the standard completion workflow section with issue number 1810 filled in]
[the standard completion workflow section with issue number 1810 filled in.
Remember: tell the agent to run only the relevant tests, not the full suite.]

## What NOT to Do

16 changes: 11 additions & 5 deletions bitsandbytes/nn/modules.py
@@ -1007,9 +1007,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
scb_name = "SCB"

# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, scb_name)
param_from_weight = getattr(self.weight, scb_name, None)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name)
param_from_state = getattr(self.state, scb_name, None)

key_name = prefix + f"{scb_name}"

@@ -1048,18 +1048,19 @@ def _load_from_state_dict(
for key in unexpected_copy:
input_name = key[len(prefix) :]
if input_name == "SCB":
if self.weight.SCB is None:
weight_scb = getattr(self.weight, "SCB", None)
if weight_scb is None:
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError(
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()",
)

input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
weight_scb.copy_(input_param)

if self.state.SCB is not None:
self.state.SCB = self.weight.SCB
self.state.SCB = weight_scb

unexpected_keys.remove(key)

@@ -1085,6 +1086,11 @@ def to(self, *args, **kwargs):
return result

def forward(self, x: torch.Tensor):
# If weight is not Int8Params (e.g. due to weight tying with a non-quantized module
# like an embedding layer), fall back to regular linear. See issue #1634.
if not isinstance(self.weight, Int8Params):
return torch.nn.functional.linear(x, self.weight, self.bias)

self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
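For context on the changes above, here is a minimal sketch (not part of the diff) of the failure mode they address. Tying `lm_head.weight` to an embedding replaces the `Int8Params` with a plain `nn.Parameter`, which has no `CB`/`SCB` attributes; the `isinstance` fallback in `forward` and the `getattr(..., None)` defaults are what let the snippet below run. Shapes and module names here are illustrative.

# Illustrative sketch, not part of this PR.
import torch
import bitsandbytes as bnb

embed = torch.nn.Embedding(32, 64)  # vocab=32, hidden=64
lm_head = bnb.nn.Linear8bitLt(64, 32, bias=False, has_fp16_weights=False)

# Tying replaces lm_head.weight (an Int8Params) with the embedding's plain
# Parameter, which carries no CB/SCB attributes.
lm_head.weight = embed.weight
assert not isinstance(lm_head.weight, bnb.nn.Int8Params)

# Previously forward() assumed Int8Params and raised AttributeError on
# self.weight.CB; with the fallback it dispatches to F.linear instead.
out = lm_head(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 32])

The new test in tests/test_linear8bitlt.py below exercises the same scenario through replace_linear and a full forward pass.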
4 changes: 1 addition & 3 deletions tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
reassembled = torch.cat(shards).reshape(qB.shape)

assert reassembled.dtype == qB.dtype
assert torch.equal(
reassembled.view(torch.uint8), qB.view(torch.uint8)
), "Bytes changed after shard roundtrip"
assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"

out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
torch.testing.assert_close(out, ref)
68 changes: 68 additions & 0 deletions tests/test_linear8bitlt.py
@@ -335,3 +335,71 @@ def test_linear8bitlt_device_movement(device):

# Accelerator outputs should match both times.
torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)


class TiedWeightModel(torch.nn.Module):
"""A minimal model with tied weights between an embedding and lm_head, mimicking
architectures like OPT where lm_head.weight is shared with the embedding layer."""

def __init__(self, vocab_size, hidden_dim):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_dim)
self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim)
self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim)
self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim)
self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
# Tie weights
self.lm_head.weight = self.embed_tokens.weight

def forward(self, x):
h = self.embed_tokens(x)
h = self.out_proj(self.q_proj(h) + self.v_proj(h))
return self.lm_head(h)


@pytest.mark.parametrize("device", get_available_devices())
def test_linear8bitlt_tied_weights_no_crash(device):
"""Test that Linear8bitLt gracefully handles tied weights (issue #1634).

When lm_head is replaced with Linear8bitLt but its weight is tied to
an embedding layer, the weight becomes a regular Parameter instead of
Int8Params. The forward pass should still work via F.linear fallback.
"""
vocab_size, hidden_dim = 32, 64
model = TiedWeightModel(vocab_size, hidden_dim)

skip_modules = ["q_proj", "v_proj"]

# Replace non-skipped linear layers with Linear8bitLt (simulating what
# HuggingFace transformers does with llm_int8_skip_modules)
from bitsandbytes.utils import replace_linear

model = replace_linear(
model,
lambda inf, outf, bias: Linear8bitLt(inf, outf, bias=bias, has_fp16_weights=False),
skip_modules=skip_modules,
copy_weights=True,
)

# Re-tie weights (as transformers does after module replacement)
model.lm_head.weight = model.embed_tokens.weight

model = model.to(device)

# Verify: skipped modules remain nn.Linear
assert type(model.q_proj) is torch.nn.Linear, "q_proj should remain nn.Linear"
assert type(model.v_proj) is torch.nn.Linear, "v_proj should remain nn.Linear"

# Verify: non-skipped, non-tied modules are Linear8bitLt
assert isinstance(model.out_proj, Linear8bitLt), "out_proj should be Linear8bitLt"

# Verify: lm_head is Linear8bitLt but with a regular Parameter (from tying)
assert isinstance(model.lm_head, Linear8bitLt), "lm_head should be Linear8bitLt"
assert not isinstance(model.lm_head.weight, bnb.nn.Int8Params), (
"lm_head.weight should be a regular Parameter due to tying"
)

# Forward pass should NOT crash (this was the bug in issue #1634)
x = torch.randint(0, vocab_size, (2, 8), device=device)
output = model(x)
assert output.shape == (2, 8, vocab_size)
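
Following the updated coordinator guidance above, this test can be exercised on its own rather than through the full suite, for example with an invocation along the lines of `pytest tests/test_linear8bitlt.py -k tied_weights -v` (illustrative, not part of the PR).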