Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ public enum Builtins {
LMDS("lmDS", true),
LMPREDICT("lmPredict", true),
LMPREDICT_STATS("lmPredictStats", true),
LLMPREDICT("llmPredict", false, true),
LOCAL("local", false),
LOG("log", false),
LOGSUMEXP("logSumExp", true),
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Opcodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ public enum Opcodes {
GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin),
RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin),
REPLACE("replace", InstructionType.ParameterizedBuiltin),
LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin),
LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin),
UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin),
REXPAND("rexpand", InstructionType.ParameterizedBuiltin),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) {

/** Parameterized operations that require named variable arguments */
public enum ParamBuiltinOp {
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND,
LOWER_TRI, UPPER_TRI,
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
TOKENIZE, TOSTRING, LIST, PARAMSERV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ public Lop constructLops()
case LOWER_TRI:
case UPPER_TRI:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
Expand Down Expand Up @@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA
|| _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST
|| _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF
|| _op == ParamBuiltinOp.PARAMSERV) {
|| _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) {
_etype = ExecType.CP;
}

Expand All @@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
switch(_op) {
case CONTAINS:
if(getTargetHop().optFindExecType() == ExecType.SPARK)
_etype = ExecType.SPARK;
_etype = ExecType.SPARK;
break;
default:
// Do not change execution type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ public String getInstructions(String output)
case CONTAINS:
case REPLACE:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
case LOWER_TRI:
case UPPER_TRI:
case TOKENIZE:
case LLMPREDICT:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT);
pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI);
pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI);

Expand Down Expand Up @@ -211,6 +212,10 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
validateOrder(output, conditional);
break;

case LLMPREDICT:
validateLlmPredict(output, conditional);
break;

case TOKENIZE:
validateTokenize(output, conditional);
break;
Expand Down Expand Up @@ -614,6 +619,42 @@ private void validateTokenize(DataIdentifier output, boolean conditional)
output.setDimensions(-1, -1);
}

private void validateLlmPredict(DataIdentifier output, boolean conditional)
{
Set<String> valid = new HashSet<>(Arrays.asList(
"target", "url", "model", "max_tokens", "temperature", "top_p", "concurrency"));
checkInvalidParameters(getOpCode(), getVarParams(), valid);
checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional);
checkStringParam(false, "llmPredict", "url", conditional);

// validate numeric parameter types at compile time (when literal).
// Note: no range validation -- different LLM servers accept different
// ranges (e.g. vLLM allows temperature=0.0, OpenAI requires >0).
// Runtime errors from the server are more informative than
// compile-time checks locked to one server's rules.
checkNumericScalarParam("llmPredict", "max_tokens", conditional);
checkNumericScalarParam("llmPredict", "temperature", conditional);
checkNumericScalarParam("llmPredict", "top_p", conditional);
checkNumericScalarParam("llmPredict", "concurrency", conditional);

output.setDataType(DataType.FRAME);
output.setValueType(ValueType.STRING);
output.setDimensions(-1, -1);
}

private void checkNumericScalarParam(String fname, String pname, boolean conditional) {
Expression expr = getVarParam(pname);
if(expr == null) return;
if(expr instanceof DataIdentifier) {
DataIdentifier di = (DataIdentifier) expr;
if(di.getDataType() != null && !di.getDataType().isScalar()) {
raiseValidateError(
String.format("Function %s: parameter '%s' must be a scalar, got %s.",
fname, pname, di.getDataType()), conditional);
}
}
}

// example: A = transformapply(target=X, meta=M, spec=s)
private void validateTransformApply(DataIdentifier output, boolean conditional)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.cp;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ConnectException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.wink.json4j.JSONObject;

public class LlmPredictCPInstruction extends ParameterizedBuiltinCPInstruction {

protected LlmPredictCPInstruction(LinkedHashMap<String, String> paramsMap,
CPOperand out, String opcode, String istr) {
super(null, paramsMap, out, opcode, istr);
}

@Override
public void processInstruction(ExecutionContext ec) {
FrameBlock prompts = ec.getFrameInput(params.get("target"));
String url = params.get("url");
String model = params.containsKey("model") ?
params.get("model") : null;
int maxTokens = params.containsKey("max_tokens") ?
Integer.parseInt(params.get("max_tokens")) : 512;
double temperature = params.containsKey("temperature") ?
Double.parseDouble(params.get("temperature")) : 0.0;
double topP = params.containsKey("top_p") ?
Double.parseDouble(params.get("top_p")) : 0.9;
int concurrency = params.containsKey("concurrency") ?
Integer.parseInt(params.get("concurrency")) : 1;
concurrency = Math.max(1, Math.min(concurrency, 128));

int n = prompts.getNumRows();
String[][] data = new String[n][];

List<Callable<String[]>> tasks = new ArrayList<>(n);
for(int i = 0; i < n; i++) {
String prompt = prompts.get(i, 0).toString();
tasks.add(() -> callLlmEndpoint(prompt, url, model, maxTokens, temperature, topP));
}

try {
if(concurrency <= 1) {
for(int i = 0; i < n; i++)
data[i] = tasks.get(i).call();
}
else {
ExecutorService pool = Executors.newFixedThreadPool(
Math.min(concurrency, n));
List<Future<String[]>> futures = pool.invokeAll(tasks);
pool.shutdown();
for(int i = 0; i < n; i++)
data[i] = futures.get(i).get();
}
}
catch(DMLRuntimeException e) {
throw e;
}
catch(Exception e) {
throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e);
}

ValueType[] schema = {ValueType.STRING, ValueType.STRING,
ValueType.INT64, ValueType.INT64, ValueType.INT64};
String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
FrameBlock fbout = new FrameBlock(schema, colNames);
for(String[] row : data)
fbout.appendRow(row);

ec.setFrameOutput(output.getName(), fbout);
ec.releaseFrameInput(params.get("target"));
}

// No retry logic by design: as a database built-in, llmPredict should
// fail fast on transient errors and let the caller (DML script) decide
// whether and how to retry. Silent retries with backoff would make
// execution time unpredictable.
private static String[] callLlmEndpoint(String prompt, String url,
String model, int maxTokens, double temperature, double topP) {
long t0 = System.nanoTime();

// validate URL and open connection
HttpURLConnection conn;
try {
conn = (HttpURLConnection) new URI(url).toURL().openConnection();
}
catch(URISyntaxException | MalformedURLException | IllegalArgumentException e) {
throw new DMLRuntimeException(
"llmPredict: invalid URL '" + url + "'. "
+ "Expected format: http://host:port/v1/completions", e);
}
catch(IOException e) {
throw new DMLRuntimeException(
"llmPredict: cannot open connection to '" + url + "'.", e);
}

try {
JSONObject req = new JSONObject();
if(model != null)
req.put("model", model);
req.put("prompt", prompt);
req.put("max_tokens", maxTokens);
req.put("temperature", temperature);
req.put("top_p", topP);

conn.setRequestMethod("POST");
conn.setRequestProperty("Content-Type", "application/json");
conn.setConnectTimeout(10_000);
conn.setReadTimeout(300_000);
conn.setDoOutput(true);

try(OutputStream os = conn.getOutputStream()) {
os.write(req.toString().getBytes(StandardCharsets.UTF_8));
}

int httpCode = conn.getResponseCode();
if(httpCode != 200) {
String errBody = "";
try(InputStream es = conn.getErrorStream()) {
if(es != null)
errBody = new String(es.readAllBytes(), StandardCharsets.UTF_8);
}
catch(Exception ignored) {}
throw new DMLRuntimeException(
"llmPredict: endpoint returned HTTP " + httpCode
+ " for '" + url + "'."
+ (errBody.isEmpty() ? "" : " Response: " + errBody));
}

String body;
try(InputStream is = conn.getInputStream()) {
body = new String(is.readAllBytes(), StandardCharsets.UTF_8);
}

JSONObject resp = new JSONObject(body);
if(!resp.has("choices") || resp.getJSONArray("choices").length() == 0) {
String errMsg = resp.has("error") ? resp.optString("error") : body;
throw new DMLRuntimeException(
"llmPredict: server response missing 'choices'. Response: " + errMsg);
}
String text = resp.getJSONArray("choices")
.getJSONObject(0).getString("text");
long elapsed = (System.nanoTime() - t0) / 1_000_000;
int inTok = 0, outTok = 0;
if(resp.has("usage")) {
JSONObject usage = resp.getJSONObject("usage");
inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0;
outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0;
}
return new String[]{prompt, text,
String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)};
}
catch(ConnectException e) {
throw new DMLRuntimeException(
"llmPredict: connection refused to '" + url + "'. "
+ "Ensure the LLM server is running and reachable.", e);
}
catch(SocketTimeoutException e) {
throw new DMLRuntimeException(
"llmPredict: timed out connecting to '" + url + "'. "
+ "Ensure the LLM server is running and reachable.", e);
}
catch(IOException e) {
throw new DMLRuntimeException(
"llmPredict: I/O error communicating with '" + url + "'.", e);
}
catch(DMLRuntimeException e) {
throw e;
}
catch(Exception e) {
throw new DMLRuntimeException(
"llmPredict: failed to get response from '" + url + "'.", e);
}
finally {
conn.disconnect();
}
}

@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME);
CPOperand urlOp = new CPOperand(params.get("url"), ValueType.STRING, DataType.SCALAR, true);
return Pair.of(output.getName(),
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcode
|| opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) {
return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
}
else if(opcode.equals(Opcodes.LLMPREDICT.toString())) {
return new LlmPredictCPInstruction(paramsMap, out, opcode, str);
}
else if(Opcodes.PARAMSERV.toString().equals(opcode)) {
return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str);
}
Expand Down Expand Up @@ -324,6 +327,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) {
ec.setFrameOutput(output.getName(), fbout);
ec.releaseFrameInput(params.get("target"));
}

else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) {
// acquire locks
FrameBlock data = ec.getFrameInput(params.get("target"));
Expand Down
Loading