diff --git a/agents/coordinator_guide.md b/agents/coordinator_guide.md
index dfc5a8b29..2608fe97d 100644
--- a/agents/coordinator_guide.md
+++ b/agents/coordinator_guide.md
@@ -121,7 +121,7 @@ project before making changes so you can verify your setup works.
 - Information about existing PRs — what they change, whether they look correct
 - Anything else the worker agent should know
 
-**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment.
+**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment. Include which specific test file(s) or test function(s) the agent should run to verify its fix — not the full suite.
 
 **6. Completion workflow.** Every prompt file must include this section verbatim, with the issue number filled in:
 
@@ -130,20 +130,29 @@ project before making changes so you can verify your setup works.
 
 After implementing and verifying the fix:
 
-1. **Commit** your changes with a message referencing the issue:
+1. **Run only the tests relevant to your change.** Do NOT run the full
+   test suite — it takes 10+ minutes and will be run separately later.
+   Instead, run the specific test file(s) that cover the code you changed:
+
+       pytest tests/test_autograd.py -v --tb=short -k "relevant_test_name"
+
+   If you wrote a new test, run that plus the existing tests in the same
+   file to check for regressions in that area.
+
+2. **Commit** your changes with a message referencing the issue:
 
        git add <files>
        git commit -m "Fix <description> (#<issue-number>)"
 
-2. **Push** the branch:
+3. **Push** the branch:
 
        git push -u origin fix/issue-<issue-number>
 
-3. **Create a pull request** with `gh pr create`. The PR body must
+4. **Create a pull request** with `gh pr create`. The PR body must
    include "Fixes #<issue-number>" so GitHub auto-links and auto-closes the
    issue on merge. Describe what the fix does and how you verified it.
 
-4. **Post to the bitsandbytes Slack channel** to notify the team.
+5. **Post to the bitsandbytes Slack channel** to notify the team.
    Write a temporary Python script to `/tmp/slack_notify.py` and run it:
 
        import json, urllib.request, sys
@@ -245,7 +254,8 @@ whether it is correct and complete before implementing from scratch.
 
 ## When You Are Done
 
-[the standard completion workflow section with issue number 1810 filled in]
+[the standard completion workflow section with issue number 1810 filled in.
+Remember: tell the agent to run only the relevant tests, not the full suite.]
 
 ## What NOT to Do
 
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9c0a647cb..3bd84d422 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1007,9 +1007,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         scb_name = "SCB"
 
         # case 1: .cuda was called, SCB is in self.weight
-        param_from_weight = getattr(self.weight, scb_name)
+        param_from_weight = getattr(self.weight, scb_name, None)
         # case 2: self.init_8bit_state was called, SCB is in self.state
-        param_from_state = getattr(self.state, scb_name)
+        param_from_state = getattr(self.state, scb_name, None)
 
         key_name = prefix + f"{scb_name}"
 
@@ -1048,7 +1048,8 @@ def _load_from_state_dict(
         for key in unexpected_copy:
             input_name = key[len(prefix) :]
             if input_name == "SCB":
-                if self.weight.SCB is None:
+                weight_scb = getattr(self.weight, "SCB", None)
+                if weight_scb is None:
                     # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError(
                         "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
@@ -1056,10 +1057,10 @@
                     )
 
                 input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
+                weight_scb.copy_(input_param)
 
                 if self.state.SCB is not None:
-                    self.state.SCB = self.weight.SCB
+                    self.state.SCB = weight_scb
 
                 unexpected_keys.remove(key)
 
@@ -1085,6 +1086,11 @@ def to(self, *args, **kwargs):
         return result
 
     def forward(self, x: torch.Tensor):
+        # If weight is not Int8Params (e.g. due to weight tying with a non-quantized module
+        # like an embedding layer), fall back to regular linear. See issue #1634.
+        if not isinstance(self.weight, Int8Params):
+            return torch.nn.functional.linear(x, self.weight, self.bias)
+
         self.state.is_training = self.training
         if self.weight.CB is not None:
             self.init_8bit_state()
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index ee8bafe80..de40d158c 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
 
     reassembled = torch.cat(shards).reshape(qB.shape)
     assert reassembled.dtype == qB.dtype
-    assert torch.equal(
-        reassembled.view(torch.uint8), qB.view(torch.uint8)
-    ), "Bytes changed after shard roundtrip"
+    assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"
 
     out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
     torch.testing.assert_close(out, ref)
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 83f207d42..f64db0a2b 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -335,3 +335,71 @@ def test_linear8bitlt_device_movement(device):
 
     # Accelerator outputs should match both times.
     torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)
+
+
+class TiedWeightModel(torch.nn.Module):
+    """A minimal model with tied weights between an embedding and lm_head, mimicking
+    architectures like OPT where lm_head.weight is shared with the embedding layer."""
+
+    def __init__(self, vocab_size, hidden_dim):
+        super().__init__()
+        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_dim)
+        self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
+        # Tie weights
+        self.lm_head.weight = self.embed_tokens.weight
+
+    def forward(self, x):
+        h = self.embed_tokens(x)
+        h = self.out_proj(self.q_proj(h) + self.v_proj(h))
+        return self.lm_head(h)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+def test_linear8bitlt_tied_weights_no_crash(device):
+    """Test that Linear8bitLt gracefully handles tied weights (issue #1634).
+
+    When lm_head is replaced with Linear8bitLt but its weight is tied to
+    an embedding layer, the weight becomes a regular Parameter instead of
+    Int8Params. The forward pass should still work via F.linear fallback.
+    """
+    vocab_size, hidden_dim = 32, 64
+    model = TiedWeightModel(vocab_size, hidden_dim)
+
+    skip_modules = ["q_proj", "v_proj"]
+
+    # Replace non-skipped linear layers with Linear8bitLt (simulating what
+    # HuggingFace transformers does with llm_int8_skip_modules)
+    from bitsandbytes.utils import replace_linear
+
+    model = replace_linear(
+        model,
+        lambda inf, outf, bias: Linear8bitLt(inf, outf, bias=bias, has_fp16_weights=False),
+        skip_modules=skip_modules,
+        copy_weights=True,
+    )
+
+    # Re-tie weights (as transformers does after module replacement)
+    model.lm_head.weight = model.embed_tokens.weight
+
+    model = model.to(device)
+
+    # Verify: skipped modules remain nn.Linear
+    assert type(model.q_proj) is torch.nn.Linear, "q_proj should remain nn.Linear"
+    assert type(model.v_proj) is torch.nn.Linear, "v_proj should remain nn.Linear"
+
+    # Verify: non-skipped, non-tied modules are Linear8bitLt
+    assert isinstance(model.out_proj, Linear8bitLt), "out_proj should be Linear8bitLt"
+
+    # Verify: lm_head is Linear8bitLt but with a regular Parameter (from tying)
+    assert isinstance(model.lm_head, Linear8bitLt), "lm_head should be Linear8bitLt"
+    assert not isinstance(model.lm_head.weight, bnb.nn.Int8Params), (
+        "lm_head.weight should be a regular Parameter due to tying"
+    )
+
+    # Forward pass should NOT crash (this was the bug in issue #1634)
+    x = torch.randint(0, vocab_size, (2, 8), device=device)
+    output = model(x)
+    assert output.shape == (2, 8, vocab_size)
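
Taken together, the modules.py hunks mean that a Linear8bitLt whose weight has been re-tied to a plain nn.Parameter (the issue #1634 situation the new test reproduces) now degrades gracefully to an ordinary linear layer instead of crashing. A minimal standalone sketch of that behaviour, assuming a bitsandbytes build that already contains the patch above; the layer sizes and variable names are illustrative only:

    # Sketch of the tied-weight case the new forward() fallback covers.
    # Assumes the patched bitsandbytes; sizes are arbitrary illustration.
    import torch
    import bitsandbytes as bnb

    hidden_dim, vocab_size = 64, 32
    embed = torch.nn.Embedding(vocab_size, hidden_dim)
    lm_head = bnb.nn.Linear8bitLt(hidden_dim, vocab_size, bias=False, has_fp16_weights=False)

    # Re-tying replaces the Int8Params weight with the embedding's plain Parameter,
    # so lm_head.weight no longer carries any CB/SCB quantization state.
    lm_head.weight = embed.weight

    x = torch.randn(2, 8, hidden_dim)
    out = lm_head(x)  # dispatches to torch.nn.functional.linear instead of the int8 path
    assert out.shape == (2, 8, vocab_size)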
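
In line with the coordinator guide's "run only the relevant tests" instruction, the new test can also be exercised on its own rather than through the full suite. One possible invocation from Python; the -k filter is an assumption about a convenient selector, not a required form:

    # Select just the new tied-weights test instead of running the whole suite.
    import pytest

    pytest.main(["tests/test_linear8bitlt.py", "-v", "--tb=short", "-k", "tied_weights"])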