diff --git a/integration-test/src/test/java/org/apache/iotdb/libudf/it/dlearn/DLearnIT.java b/integration-test/src/test/java/org/apache/iotdb/libudf/it/dlearn/DLearnIT.java index 480dff4e46dca..b4a1c8daa92ab 100644 --- a/integration-test/src/test/java/org/apache/iotdb/libudf/it/dlearn/DLearnIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/libudf/it/dlearn/DLearnIT.java @@ -94,6 +94,11 @@ private static void createTimeSeries() { + "datatype=double, " + "encoding=plain, " + "compression=uncompressed"); + statement.addBatch( + "create timeseries root.vehicle.d3.s1 with " + + "datatype=int32, " + + "encoding=plain, " + + "compression=uncompressed"); statement.executeBatch(); } catch (SQLException throwable) { fail(throwable.getMessage()); @@ -168,6 +173,16 @@ private static void generateData() { String.format( "insert into root.vehicle.d2(timestamp,s1,s2,s3,s4) values(%d,%d,%d,%d,%d)", 900, 4, 4, 4, 4)); + // Toy series for cluster UDF (l=3, k=2): windows [1,2,3], [10,20,30], [1,5,1]. With + // norm=false, + // k-means groups the first two windows; k-shape / medoidshape group windows 0 and 2 + // (shape-related). + int[] toy = {1, 2, 3, 10, 20, 30, 1, 5, 1}; + for (int i = 0; i < toy.length; i++) { + statement.addBatch( + String.format( + "insert into root.vehicle.d3(timestamp,s1) values(%d,%d)", (i + 1) * 100, toy[i])); + } statement.executeBatch(); } catch (SQLException throwable) { fail(throwable.getMessage()); @@ -179,6 +194,7 @@ private static void registerUDF() { Statement statement = connection.createStatement()) { statement.execute("create function iqr as 'org.apache.iotdb.library.anomaly.UDTFIQR'"); statement.execute("create function ar as 'org.apache.iotdb.library.dlearn.UDTFAR'"); + statement.execute("create function cluster as 'org.apache.iotdb.library.dlearn.UDTFCluster'"); } catch (SQLException throwable) { fail(throwable.getMessage()); } @@ -308,4 +324,70 @@ public void testAR4() { fail(throwable.getMessage()); } } + + @Test + public void testCluster1() { + String sqlStr = + "select cluster(d3.s1, 'l'='3', 'k'='2', 'method'='kmeans', 'norm'='false', " + + "'maxiter'='50', 'output'='label') from root.vehicle"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(sqlStr)) { + resultSet.next(); + int l0 = resultSet.getInt(2); + resultSet.next(); + int l1 = resultSet.getInt(2); + resultSet.next(); + int l2 = resultSet.getInt(2); + Assert.assertFalse(resultSet.next()); + Assert.assertEquals(l0, l2); + Assert.assertNotEquals(l0, l1); + } catch (SQLException throwable) { + fail(throwable.getMessage()); + } + } + + @Test + public void testCluster2() { + String sqlStr = + "select cluster(d3.s1, 'l'='3', 'k'='2', 'method'='kshape', 'norm'='true', " + + "'maxiter'='50', 'output'='label') from root.vehicle"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(sqlStr)) { + resultSet.next(); + int l0 = resultSet.getInt(2); + resultSet.next(); + int l1 = resultSet.getInt(2); + resultSet.next(); + int l2 = resultSet.getInt(2); + Assert.assertFalse(resultSet.next()); + Assert.assertEquals(l0, l1); + Assert.assertNotEquals(l0, l2); + } catch (SQLException throwable) { + fail(throwable.getMessage()); + } + } + + @Test + public void testCluster3() { + String sqlStr = + "select cluster(d3.s1, 'l'='3', 'k'='2', 'method'='medoidshape', 'norm'='true', " + + "'sample_rate'='1', 'maxiter'='50', 'output'='label') from root.vehicle"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(sqlStr)) { + resultSet.next(); + int l0 = resultSet.getInt(2); + resultSet.next(); + int l1 = resultSet.getInt(2); + resultSet.next(); + int l2 = resultSet.getInt(2); + Assert.assertFalse(resultSet.next()); + Assert.assertEquals(l0, l1); + Assert.assertNotEquals(l0, l2); + } catch (SQLException throwable) { + fail(throwable.getMessage()); + } + } } diff --git a/library-udf/src/assembly/tools/register-UDF.bat b/library-udf/src/assembly/tools/register-UDF.bat index 0eb333d88c7bc..f30db5d2a3a7d 100644 --- a/library-udf/src/assembly/tools/register-UDF.bat +++ b/library-udf/src/assembly/tools/register-UDF.bat @@ -25,83 +25,84 @@ @REM Data Profiling -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function distinct as 'org.apache.iotdb.library.dprofile.UDTFDistinct'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function histogram as 'org.apache.iotdb.library.dprofile.UDTFHistogram'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function integral as 'org.apache.iotdb.library.dprofile.UDAFIntegral'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function integralavg as 'org.apache.iotdb.library.dprofile.UDAFIntegralAvg'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function mad as 'org.apache.iotdb.library.dprofile.UDAFMad'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function median as 'org.apache.iotdb.library.dprofile.UDAFMedian'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function percentile as 'org.apache.iotdb.library.dprofile.UDAFPercentile'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function quantile as 'org.apache.iotdb.library.dprofile.UDAFQuantile'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function period as 'org.apache.iotdb.library.dprofile.UDAFPeriod'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function qlb as 'org.apache.iotdb.library.dprofile.UDTFQLB'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function re_sample as 'org.apache.iotdb.library.dprofile.UDTFResample'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function sample as 'org.apache.iotdb.library.dprofile.UDTFSample'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function segment as 'org.apache.iotdb.library.dprofile.UDTFSegment'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function skew as 'org.apache.iotdb.library.dprofile.UDAFSkew'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function spread as 'org.apache.iotdb.library.dprofile.UDAFSpread'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function minmax as 'org.apache.iotdb.library.dprofile.UDTFMinMax'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function zscore as 'org.apache.iotdb.library.dprofile.UDTFZScore'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function spline as 'org.apache.iotdb.library.dprofile.UDTFSpline'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function mvavg as 'org.apache.iotdb.library.dprofile.UDTFMvAvg'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function acf as 'org.apache.iotdb.library.dprofile.UDTFACF'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pacf as 'org.apache.iotdb.library.dprofile.UDTFPACF'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function distinct as 'org.apache.iotdb.library.dprofile.UDTFDistinct'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function histogram as 'org.apache.iotdb.library.dprofile.UDTFHistogram'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function integral as 'org.apache.iotdb.library.dprofile.UDAFIntegral'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function integralavg as 'org.apache.iotdb.library.dprofile.UDAFIntegralAvg'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function mad as 'org.apache.iotdb.library.dprofile.UDAFMad'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function median as 'org.apache.iotdb.library.dprofile.UDAFMedian'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function percentile as 'org.apache.iotdb.library.dprofile.UDAFPercentile'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function quantile as 'org.apache.iotdb.library.dprofile.UDAFQuantile'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function period as 'org.apache.iotdb.library.dprofile.UDAFPeriod'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function qlb as 'org.apache.iotdb.library.dprofile.UDTFQLB'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function re_sample as 'org.apache.iotdb.library.dprofile.UDTFResample'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function sample as 'org.apache.iotdb.library.dprofile.UDTFSample'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function segment as 'org.apache.iotdb.library.dprofile.UDTFSegment'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function skew as 'org.apache.iotdb.library.dprofile.UDAFSkew'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function spread as 'org.apache.iotdb.library.dprofile.UDAFSpread'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function minmax as 'org.apache.iotdb.library.dprofile.UDTFMinMax'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function zscore as 'org.apache.iotdb.library.dprofile.UDTFZScore'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function spline as 'org.apache.iotdb.library.dprofile.UDTFSpline'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function mvavg as 'org.apache.iotdb.library.dprofile.UDTFMvAvg'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function acf as 'org.apache.iotdb.library.dprofile.UDTFACF'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pacf as 'org.apache.iotdb.library.dprofile.UDTFPACF'" @REM Data Quality -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function completeness as 'org.apache.iotdb.library.dquality.UDTFCompleteness'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consistency as 'org.apache.iotdb.library.dquality.UDTFConsistency'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function timeliness as 'org.apache.iotdb.library.dquality.UDTFTimeliness'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function validity as 'org.apache.iotdb.library.dquality.UDTFValidity'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function completeness as 'org.apache.iotdb.library.dquality.UDTFCompleteness'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consistency as 'org.apache.iotdb.library.dquality.UDTFConsistency'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function timeliness as 'org.apache.iotdb.library.dquality.UDTFTimeliness'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function validity as 'org.apache.iotdb.library.dquality.UDTFValidity'" @REM Data Repairing -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function timestamprepair as 'org.apache.iotdb.library.drepair.UDTFTimestampRepair'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function valuerepair as 'org.apache.iotdb.library.drepair.UDTFValueRepair'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function valuefill as 'org.apache.iotdb.library.drepair.UDTFValueFill'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function timestamprepair as 'org.apache.iotdb.library.drepair.UDTFTimestampRepair'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function valuerepair as 'org.apache.iotdb.library.drepair.UDTFValueRepair'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function valuefill as 'org.apache.iotdb.library.drepair.UDTFValueFill'" @REM Data Matching -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function cov as 'org.apache.iotdb.library.dmatch.UDAFCov'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function xcorr as 'org.apache.iotdb.library.dmatch.UDTFXCorr'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function dtw as 'org.apache.iotdb.library.dmatch.UDAFDtw'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function cov as 'org.apache.iotdb.library.dmatch.UDAFCov'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function xcorr as 'org.apache.iotdb.library.dmatch.UDTFXCorr'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function dtw as 'org.apache.iotdb.library.dmatch.UDAFDtw'" call ../bin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ptnsym as 'org.apache.iotdb.library.dmatch.UDTFPtnSym'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pearson as 'org.apache.iotdb.library.dmatch.UDAFPearson'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pearson as 'org.apache.iotdb.library.dmatch.UDAFPearson'" @REM Anomaly Detection -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ksigma as 'org.apache.iotdb.library.anomaly.UDTFKSigma'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function lof as 'org.apache.iotdb.library.anomaly.UDTFLOF'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function range as 'org.apache.iotdb.library.anomaly.UDTFRange'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function iqr as 'org.apache.iotdb.library.anomaly.UDTFIQR'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function twosidedfilter as 'org.apache.iotdb.library.anomaly.UDTFTwoSidedFilter'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function missdetect as 'org.apache.iotdb.library.anomaly.UDTFMissDetect'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function outlier as 'org.apache.iotdb.library.anomaly.UDTFOutlier'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ksigma as 'org.apache.iotdb.library.anomaly.UDTFKSigma'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function lof as 'org.apache.iotdb.library.anomaly.UDTFLOF'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function range as 'org.apache.iotdb.library.anomaly.UDTFRange'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function iqr as 'org.apache.iotdb.library.anomaly.UDTFIQR'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function twosidedfilter as 'org.apache.iotdb.library.anomaly.UDTFTwoSidedFilter'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function missdetect as 'org.apache.iotdb.library.anomaly.UDTFMissDetect'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function outlier as 'org.apache.iotdb.library.anomaly.UDTFOutlier'" @REM Frequency Domain -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function fft as 'org.apache.iotdb.library.frequency.UDTFFFT'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function conv as 'org.apache.iotdb.library.frequency.UDTFConv'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function deconv as 'org.apache.iotdb.library.frequency.UDTFDeconv'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function lowpass as 'org.apache.iotdb.library.frequency.UDTFLowPass'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function highpass as 'org.apache.iotdb.library.frequency.UDTFHighPass'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function dwt as 'org.apache.iotdb.library.frequency.UDTFDWT'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function idwt as 'org.apache.iotdb.library.frequency.UDTFIDWT'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ifft as 'org.apache.iotdb.library.frequency.UDTFIFFT'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function fft as 'org.apache.iotdb.library.frequency.UDTFFFT'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function conv as 'org.apache.iotdb.library.frequency.UDTFConv'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function deconv as 'org.apache.iotdb.library.frequency.UDTFDeconv'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function lowpass as 'org.apache.iotdb.library.frequency.UDTFLowPass'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function highpass as 'org.apache.iotdb.library.frequency.UDTFHighPass'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function dwt as 'org.apache.iotdb.library.frequency.UDTFDWT'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function idwt as 'org.apache.iotdb.library.frequency.UDTFIDWT'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ifft as 'org.apache.iotdb.library.frequency.UDTFIFFT'" @REM Series Discovery -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consecutivesequences as 'org.apache.iotdb.library.series.UDTFConsecutiveSequences'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consecutivewindows as 'org.apache.iotdb.library.series.UDTFConsecutiveWindows'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consecutivesequences as 'org.apache.iotdb.library.series.UDTFConsecutiveSequences'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function consecutivewindows as 'org.apache.iotdb.library.series.UDTFConsecutiveWindows'" @REM String Processing -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexsplit as 'org.apache.iotdb.library.string.UDTFRegexSplit'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexmatch as 'org.apache.iotdb.library.string.UDTFRegexMatch'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function strreplace as 'org.apache.iotdb.library.string.UDTFStrReplace'" -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexreplace as 'org.apache.iotdb.library.string.UDTFRegexReplace'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexsplit as 'org.apache.iotdb.library.string.UDTFRegexSplit'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexmatch as 'org.apache.iotdb.library.string.UDTFRegexMatch'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function strreplace as 'org.apache.iotdb.library.string.UDTFStrReplace'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function regexreplace as 'org.apache.iotdb.library.string.UDTFRegexReplace'" @REM Machine Learning -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ar as 'org.apache.iotdb.library.dlearn.UDTFAR'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function ar as 'org.apache.iotdb.library.dlearn.UDTFAR'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function cluster as 'org.apache.iotdb.library.dlearn.UDTFCluster'" @REM Match -call ../sbin/windows/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pattern_match as 'org.apache.iotdb.library.match.UDAFPatternMatch'" +call ../sbin/start-cli.bat -h %host% -p %rpcPort% -u %user% -pw %pass% -e "create function pattern_match as 'org.apache.iotdb.library.match.UDAFPatternMatch'" diff --git a/library-udf/src/assembly/tools/register-UDF.sh b/library-udf/src/assembly/tools/register-UDF.sh index 16ab59f143baa..faaa4df68b5d1 100755 --- a/library-udf/src/assembly/tools/register-UDF.sh +++ b/library-udf/src/assembly/tools/register-UDF.sh @@ -102,6 +102,7 @@ pass=root # Machine Learning ../sbin/start-cli.sh -h $host -p $rpcPort -u $user -pw $pass -e "create function ar as 'org.apache.iotdb.library.dlearn.UDTFAR'" +../sbin/start-cli.sh -h $host -p $rpcPort -u $user -pw $pass -e "create function cluster as 'org.apache.iotdb.library.dlearn.UDTFCluster'" # Match ../sbin/start-cli.sh -h $host -p $rpcPort -u $user -pw $pass -e "create function pattern_match as 'org.apache.iotdb.library.match.UDAFPatternMatch'" diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dlearn/UDTFCluster.java b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/UDTFCluster.java new file mode 100644 index 0000000000000..f64f9747992bc --- /dev/null +++ b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/UDTFCluster.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.library.dlearn; + +import org.apache.iotdb.library.dlearn.util.cluster.KMeans; +import org.apache.iotdb.library.dlearn.util.cluster.KShape; +import org.apache.iotdb.library.dlearn.util.cluster.MedoidShape; +import org.apache.iotdb.library.util.Util; +import org.apache.iotdb.udf.api.UDTF; +import org.apache.iotdb.udf.api.access.Row; +import org.apache.iotdb.udf.api.collector.PointCollector; +import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations; +import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator; +import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters; +import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy; +import org.apache.iotdb.udf.api.exception.UDFException; +import org.apache.iotdb.udf.api.type.Type; + +import java.util.ArrayList; +import java.util.List; + +/** + * Clusters a time series by partitioning it into non-overlapping subsequences of length l. + * Parameters: l, k, method (default kmeans), norm, maxiter, output; medoidshape also uses + * sample_rate (greedy sampling ratio; use 1 when the window count is small). Requires at least k + * windows. + */ +public class UDTFCluster implements UDTF { + + private static final String METHOD_KMEANS = "kmeans"; + private static final String METHOD_KSHAPE = "kshape"; + private static final String METHOD_MEDOIDSHAPE = "medoidshape"; + + private static final String OUTPUT_LABEL = "label"; + private static final String OUTPUT_CENTROID = "centroid"; + + private static final int DEFAULT_MAX_ITER = 200; + private static final double DEFAULT_SAMPLE_RATE = 0.3; + private static final String DEFAULT_METHOD = METHOD_KMEANS; + + private int l; + private int k; + private String method; + private boolean norm; + private int maxIter; + private String output; + private double sampleRate; + + private final List timestamps = new ArrayList<>(); + private final List values = new ArrayList<>(); + + @Override + public void validate(UDFParameterValidator validator) throws Exception { + validator + .validateInputSeriesNumber(1) + .validateInputSeriesDataType(0, Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE) + .validate( + x -> (int) x > 0, + "Parameter l must be a positive integer.", + validator.getParameters().getInt("l")) + .validate( + x -> (int) x >= 2, + "Parameter k must be at least 2.", + validator.getParameters().getInt("k")) + .validate( + x -> { + String m = ((String) x).toLowerCase(); + return METHOD_KMEANS.equals(m) + || METHOD_KSHAPE.equals(m) + || METHOD_MEDOIDSHAPE.equals(m); + }, + "Parameter method must be one of: kmeans, kshape, medoidshape.", + validator.getParameters().getStringOrDefault("method", DEFAULT_METHOD)) + .validate( + x -> (int) x >= 1, + "Parameter maxiter must be a positive integer.", + validator.getParameters().getIntOrDefault("maxiter", DEFAULT_MAX_ITER)) + .validate( + x -> { + String o = ((String) x).toLowerCase(); + return OUTPUT_LABEL.equals(o) || OUTPUT_CENTROID.equals(o); + }, + "Parameter output must be label or centroid.", + validator.getParameters().getStringOrDefault("output", OUTPUT_LABEL)) + .validate( + x -> { + double d = ((Number) x).doubleValue(); + return d > 0 && d <= 1.0; + }, + "Parameter sample_rate must be in (0, 1].", + validator.getParameters().getDoubleOrDefault("sample_rate", DEFAULT_SAMPLE_RATE)); + } + + @Override + public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations) + throws Exception { + this.output = parameters.getStringOrDefault("output", OUTPUT_LABEL).toLowerCase(); + if (OUTPUT_CENTROID.equals(output)) { + configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.DOUBLE); + } else { + configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.INT32); + } + this.l = parameters.getInt("l"); + this.k = parameters.getInt("k"); + this.method = parameters.getStringOrDefault("method", DEFAULT_METHOD).toLowerCase(); + this.norm = parameters.getBooleanOrDefault("norm", true); + this.maxIter = parameters.getIntOrDefault("maxiter", DEFAULT_MAX_ITER); + this.sampleRate = parameters.getDoubleOrDefault("sample_rate", DEFAULT_SAMPLE_RATE); + timestamps.clear(); + values.clear(); + } + + @Override + public void transform(Row row, PointCollector collector) throws Exception { + if (!row.isNull(0)) { + timestamps.add(row.getTime()); + values.add(Util.getValueAsDouble(row)); + } + } + + @Override + public void terminate(PointCollector collector) throws Exception { + int n = values.size(); + if (n < l) { + throw new UDFException( + "Time series length must be at least l; got " + n + " points, l=" + l + "."); + } + int numWindows = n / l; + if (numWindows < k) { + throw new UDFException( + "Not enough non-overlapping windows: got " + + numWindows + + " windows, need at least k=" + + k + + "."); + } + + double[][] windows = new double[numWindows][l]; + long[] windowStartTime = new long[numWindows]; + for (int w = 0; w < numWindows; w++) { + windowStartTime[w] = timestamps.get(w * l); + for (int j = 0; j < l; j++) { + windows[w][j] = values.get(w * l + j); + } + } + + if (OUTPUT_LABEL.equals(output)) { + int[] labels; + if (METHOD_KMEANS.equals(method)) { + KMeans km = new KMeans(); + km.fit(windows, k, norm, maxIter); + labels = km.getLabels(); + } else if (METHOD_KSHAPE.equals(method)) { + KShape ks = new KShape(); + ks.fit(windows, k, norm, maxIter); + labels = ks.getLabels(); + } else if (METHOD_MEDOIDSHAPE.equals(method)) { + MedoidShape ms = new MedoidShape(); + ms.setSampleRate(sampleRate); + ms.fit(windows, k, norm, maxIter); + labels = ms.getLabels(); + } else { + throw new UDFException("Unsupported method: " + method); + } + for (int w = 0; w < numWindows; w++) { + collector.putInt(windowStartTime[w], labels[w]); + } + } else { + double[][] centroids; + if (METHOD_KMEANS.equals(method)) { + KMeans km = new KMeans(); + km.fit(windows, k, norm, maxIter); + centroids = km.getCentroids(); + } else if (METHOD_KSHAPE.equals(method)) { + KShape ks = new KShape(); + ks.fit(windows, k, norm, maxIter); + centroids = ks.getCentroids(); + } else if (METHOD_MEDOIDSHAPE.equals(method)) { + MedoidShape ms = new MedoidShape(); + ms.setSampleRate(sampleRate); + ms.fit(windows, k, norm, maxIter); + centroids = ms.getCentroids(); + } else { + throw new UDFException("Unsupported method: " + method); + } + emitConcatenatedCentroids(collector, centroids); + } + } + + private static void emitConcatenatedCentroids(PointCollector collector, double[][] centroids) + throws Exception { + long t = 0L; + for (double[] row : centroids) { + for (double v : row) { + collector.putDouble(t++, v); + } + } + } +} diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/ClusterUtils.java b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/ClusterUtils.java new file mode 100644 index 0000000000000..a646adb2bed30 --- /dev/null +++ b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/ClusterUtils.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.library.dlearn.util.cluster; + +import org.apache.commons.math3.complex.Complex; +import org.apache.commons.math3.transform.DftNormalization; +import org.apache.commons.math3.transform.FastFourierTransformer; +import org.apache.commons.math3.transform.TransformType; + +import java.util.Arrays; + +/** + * Subsequence z-normalize, Euclidean distance, and FFT-based NCC / SBD (shared by KShape and + * MedoidShape). + */ +public final class ClusterUtils { + + public static final double EPS = 1e-9; + + private static final FastFourierTransformer FFT = + new FastFourierTransformer(DftNormalization.STANDARD); + + private ClusterUtils() {} + + public static double[] maybeZNormalize(double[] a, boolean normalize) { + if (normalize) { + return zNormalize(a); + } + return Arrays.copyOf(a, a.length); + } + + public static double[] zNormalize(double[] a) { + int n = a.length; + double sum = 0.0; + for (double v : a) { + sum += v; + } + double mean = sum / n; + double var = 0.0; + for (double v : a) { + double d = v - mean; + var += d * d; + } + var /= n; + double std = Math.sqrt(Math.max(var, 0.0)); + double[] z = new double[n]; + if (std < EPS) { + return z; + } + for (int i = 0; i < n; i++) { + z[i] = (a[i] - mean) / std; + } + return z; + } + + public static double squaredEuclidean(double[] a, double[] b) { + double s = 0.0; + for (int i = 0; i < a.length; i++) { + double d = a[i] - b[i]; + s += d * d; + } + return s; + } + + public static int findLargestCluster(int[] counts) { + int best = 0; + for (int i = 1; i < counts.length; i++) { + if (counts[i] > counts[best]) { + best = i; + } + } + return best; + } + + /** + * Maximum over the normalized cross-correlation sequence (FFT); used for SBD and MedoidShape + * objective. + */ + public static double maxNcc(double[] x, double[] y) { + double[] cc = nccFft(x, y); + double max = Double.NEGATIVE_INFINITY; + for (double v : cc) { + if (v > max) { + max = v; + } + } + return max; + } + + /** SBD: 1 − max NCC (consistent with the NCC-based definition in k-Shape / FastKShape). */ + public static double shapeDistance(double[] x, double[] y) { + return 1.0 - maxNcc(x, y); + } + + public static double symmetricSbd(double[] a, double[] b) { + return 0.5 * (shapeDistance(a, b) + shapeDistance(b, a)); + } + + private static double[] nccFft(double[] x, double[] y) { + int xLen = x.length; + double den = l2Norm(x) * l2Norm(y); + if (den < 1e-9) { + den = Double.POSITIVE_INFINITY; + } + int fftSize = 1 << (32 - Integer.numberOfLeadingZeros(2 * xLen - 1)); + + Complex[] cx = new Complex[fftSize]; + Complex[] cy = new Complex[fftSize]; + for (int i = 0; i < fftSize; i++) { + cx[i] = new Complex(i < xLen ? x[i] : 0.0, 0.0); + cy[i] = new Complex(i < xLen ? y[i] : 0.0, 0.0); + } + Complex[] fx = FFT.transform(cx, TransformType.FORWARD); + Complex[] fy = FFT.transform(cy, TransformType.FORWARD); + Complex[] prod = new Complex[fftSize]; + for (int i = 0; i < fftSize; i++) { + prod[i] = fx[i].multiply(fy[i].conjugate()); + } + Complex[] ccFull = FFT.transform(prod, TransformType.INVERSE); + + double[] ccPacked = new double[2 * xLen - 1]; + int p = 0; + for (int i = fftSize - (xLen - 1); i < fftSize; i++) { + ccPacked[p++] = ccFull[i].getReal() / den; + } + for (int i = 0; i < xLen; i++) { + ccPacked[p++] = ccFull[i].getReal() / den; + } + return ccPacked; + } + + private static double l2Norm(double[] v) { + double s = 0.0; + for (double x : v) { + s += x * x; + } + return Math.sqrt(s); + } +} diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KMeans.java b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KMeans.java new file mode 100644 index 0000000000000..37c44eade96bf --- /dev/null +++ b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KMeans.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.library.dlearn.util.cluster; + +import java.util.Arrays; + +/** + * Univariate subsequence k-means (Lloyd); optionally z-normalize, then cluster in Euclidean space. + */ +public class KMeans { + + private double[][] centroids; + private int[] labels; + + public void fit(double[][] samples, int k, boolean normalize, int maxIterations) { + validate(samples, k, maxIterations); + int n = samples.length; + int dim = samples[0].length; + + double[][] z = new double[n][dim]; + for (int i = 0; i < n; i++) { + z[i] = ClusterUtils.maybeZNormalize(samples[i], normalize); + } + + centroids = new double[k][dim]; + for (int c = 0; c < k; c++) { + System.arraycopy(z[c], 0, centroids[c], 0, dim); + } + + labels = new int[n]; + Arrays.fill(labels, -1); + + for (int iter = 0; iter < maxIterations; iter++) { + double[][] prevCentroids = new double[k][dim]; + for (int c = 0; c < k; c++) { + System.arraycopy(centroids[c], 0, prevCentroids[c], 0, dim); + } + + boolean changed = false; + for (int i = 0; i < n; i++) { + int best = 0; + double bestDist = Double.POSITIVE_INFINITY; + for (int c = 0; c < k; c++) { + double d = ClusterUtils.squaredEuclidean(z[i], centroids[c]); + if (d < bestDist) { + bestDist = d; + best = c; + } + } + if (labels[i] != best) { + labels[i] = best; + changed = true; + } + } + + double[][] newCentroids = new double[k][dim]; + int[] counts = new int[k]; + for (int i = 0; i < n; i++) { + int c = labels[i]; + counts[c]++; + for (int d = 0; d < dim; d++) { + newCentroids[c][d] += z[i][d]; + } + } + for (int c = 0; c < k; c++) { + if (counts[c] == 0) { + int donor = ClusterUtils.findLargestCluster(counts); + System.arraycopy(prevCentroids[donor], 0, centroids[c], 0, dim); + for (int d = 0; d < dim; d++) { + centroids[c][d] += (d == 0 ? 1e-4 : -1e-4); + } + } else { + for (int d = 0; d < dim; d++) { + centroids[c][d] = newCentroids[c][d] / counts[c]; + } + } + } + + if (!changed) { + break; + } + } + } + + public double[][] getCentroids() { + return centroids; + } + + public int[] getLabels() { + return labels; + } + + private static void validate(double[][] samples, int k, int maxIterations) { + if (samples == null || samples.length == 0) { + throw new IllegalArgumentException("samples must be non-empty."); + } + if (k < 2 || k > samples.length) { + throw new IllegalArgumentException("k must satisfy 2 <= k <= samples.length."); + } + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be at least 1."); + } + int dim = samples[0].length; + if (dim == 0) { + throw new IllegalArgumentException("sample dimension must be positive."); + } + for (double[] row : samples) { + if (row == null || row.length != dim) { + throw new IllegalArgumentException("All samples must have the same length."); + } + } + } +} diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KShape.java b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KShape.java new file mode 100644 index 0000000000000..315e1c51d2f06 --- /dev/null +++ b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/KShape.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.library.dlearn.util.cluster; + +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.commons.math3.linear.SingularValueDecomposition; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * k-Shape: assignment uses {@link ClusterUtils#shapeDistance} (SBD = 1 − max NCC); centroids are + * the first right singular vector of the cluster matrix from SVD, sign correction, then z-normalize + * or L2 normalization. + */ +public class KShape { + + private double[][] centroids; + private int[] labels; + + public void fit(double[][] samples, int k, boolean normalize, int maxIterations) { + validate(samples, k, maxIterations); + int n = samples.length; + int dim = samples[0].length; + + double[][] z = new double[n][dim]; + for (int i = 0; i < n; i++) { + z[i] = ClusterUtils.maybeZNormalize(samples[i], normalize); + } + + centroids = new double[k][dim]; + for (int c = 0; c < k; c++) { + System.arraycopy(z[c], 0, centroids[c], 0, dim); + } + + labels = new int[n]; + Arrays.fill(labels, -1); + + for (int iter = 0; iter < maxIterations; iter++) { + double[][] prevCentroids = new double[k][dim]; + for (int c = 0; c < k; c++) { + System.arraycopy(centroids[c], 0, prevCentroids[c], 0, dim); + } + + boolean changed = false; + for (int i = 0; i < n; i++) { + int best = 0; + double bestDist = Double.POSITIVE_INFINITY; + for (int c = 0; c < k; c++) { + double d = ClusterUtils.shapeDistance(z[i], centroids[c]); + if (d < bestDist) { + bestDist = d; + best = c; + } + } + if (labels[i] != best) { + labels[i] = best; + changed = true; + } + } + + int[] counts = new int[k]; + @SuppressWarnings("unchecked") + List[] byCluster = new List[k]; + for (int c = 0; c < k; c++) { + byCluster[c] = new ArrayList<>(); + } + for (int i = 0; i < n; i++) { + int c = labels[i]; + counts[c]++; + byCluster[c].add(z[i]); + } + + for (int c = 0; c < k; c++) { + if (counts[c] == 0) { + int donor = ClusterUtils.findLargestCluster(counts); + System.arraycopy(prevCentroids[donor], 0, centroids[c], 0, dim); + } else { + List members = byCluster[c]; + double[][] mat = new double[members.size()][dim]; + for (int i = 0; i < members.size(); i++) { + mat[i] = members.get(i); + } + centroids[c] = centroidFromSvd(mat, normalize); + } + } + + if (!changed) { + break; + } + } + } + + public double[][] getCentroids() { + return centroids; + } + + public int[] getLabels() { + return labels; + } + + private static double[] centroidFromSvd(double[][] members, boolean zNormalizeCentroid) { + int m = members.length; + int dim = members[0].length; + if (m == 1) { + double[] u = Arrays.copyOf(members[0], dim); + return zNormalizeCentroid ? ClusterUtils.zNormalize(u) : l2Unit(u); + } + RealMatrix y = MatrixUtils.createRealMatrix(members); + SingularValueDecomposition svd = new SingularValueDecomposition(y); + RealMatrix v = svd.getV(); + RealVector col0 = v.getColumnVector(0); + double[] r = col0.toArray(); + double sumDot = 0.0; + for (double[] row : members) { + sumDot += dot(row, r); + } + if (sumDot < 0) { + for (int i = 0; i < r.length; i++) { + r[i] = -r[i]; + } + } + return zNormalizeCentroid ? ClusterUtils.zNormalize(r) : l2Unit(r); + } + + private static double dot(double[] a, double[] b) { + double s = 0.0; + for (int i = 0; i < a.length; i++) { + s += a[i] * b[i]; + } + return s; + } + + private static double[] l2Unit(double[] v) { + double s = 0.0; + for (double x : v) { + s += x * x; + } + s = Math.sqrt(s); + if (s < ClusterUtils.EPS) { + return new double[v.length]; + } + double[] o = new double[v.length]; + for (int i = 0; i < v.length; i++) { + o[i] = v[i] / s; + } + return o; + } + + private static void validate(double[][] samples, int k, int maxIterations) { + if (samples == null || samples.length == 0) { + throw new IllegalArgumentException("samples must be non-empty."); + } + if (k < 2 || k > samples.length) { + throw new IllegalArgumentException("k must satisfy 2 <= k <= samples.length."); + } + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be at least 1."); + } + int dim = samples[0].length; + if (dim == 0) { + throw new IllegalArgumentException("sample dimension must be positive."); + } + for (double[] row : samples) { + if (row == null || row.length != dim) { + throw new IllegalArgumentException("All samples must have the same length."); + } + } + } +} diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/MedoidShape.java b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/MedoidShape.java new file mode 100644 index 0000000000000..31f9d8c8df1a0 --- /dev/null +++ b/library-udf/src/main/java/org/apache/iotdb/library/dlearn/util/cluster/MedoidShape.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.library.dlearn.util.cluster; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.Set; + +/** + * Coarse clustering: {@link KMeans} uses {@code min(2k, n)} clusters (n = number of windows); + * greedy fastKShape picks k representatives; both labels and the objective use {@link + * ClusterUtils#maxNcc}. + */ +public class MedoidShape { + + private double sampleRate = 0.3; + private Random random = new Random(); + + /** Overrides the RNG used for greedy sampling (default is {@link Random#Random()}). */ + public void setRandom(Random random) { + this.random = Objects.requireNonNull(random); + } + + private double[][] centroids; + private int[] labels; + + public void setSampleRate(double sampleRate) { + if (sampleRate <= 0 || sampleRate > 1.0) { + throw new IllegalArgumentException("sampleRate must be in (0, 1]."); + } + this.sampleRate = sampleRate; + } + + public double getSampleRate() { + return sampleRate; + } + + public void fit(double[][] samples, int k, boolean normalize, int maxIterations) { + validate(samples, k, maxIterations); + int n = samples.length; + int dim = samples[0].length; + + int coarseK = Math.min(2 * k, n); + + double[][] x = new double[n][dim]; + for (int i = 0; i < n; i++) { + x[i] = ClusterUtils.maybeZNormalize(samples[i], normalize); + } + + KMeans coarse = new KMeans(); + coarse.fit(x, coarseK, false, maxIterations); + double[][] euclideanCentroids = coarse.getCentroids(); + int[] kmLabels = coarse.getLabels(); + long[] clusterSize = new long[coarseK]; + for (int lb : kmLabels) { + clusterSize[lb]++; + } + + centroids = fastKShape(x, k, sampleRate, dim, euclideanCentroids, clusterSize, random); + + labels = new int[n]; + for (int i = 0; i < n; i++) { + double maxNcc = Double.NEGATIVE_INFINITY; + int label = -1; + for (int j = 0; j < k; j++) { + double cur = ClusterUtils.maxNcc(x[i], centroids[j]); + if (cur > maxNcc) { + maxNcc = cur; + label = j; + } + } + labels[i] = label; + } + } + + public double[][] getCentroids() { + return centroids; + } + + public int[] getLabels() { + return labels; + } + + private static double[][] fastKShape( + double[][] x, + int k, + double r, + int dim, + double[][] euclideanCentroids, + long[] clusterSize, + Random rnd) { + int n = x.length; + if (n <= k) { + double[][] out = new double[k][dim]; + for (int i = 0; i < n; i++) { + out[i] = Arrays.copyOf(x[i], dim); + } + for (int i = n; i < k; i++) { + out[i] = Arrays.copyOf(x[n - 1], dim); + } + return out; + } + + List picked = new ArrayList<>(); + Set coresetIdx = new HashSet<>(); + + for (int round = 0; round < k; round++) { + List pool = new ArrayList<>(); + for (int i = 0; i < n; i++) { + if (!coresetIdx.contains(i)) { + pool.add(i); + } + } + if (pool.isEmpty()) { + throw new IllegalStateException("fastKShape: empty candidate pool."); + } + int sampleCount = Math.max(1, (int) (r * n)); + sampleCount = Math.min(sampleCount, pool.size()); + Collections.shuffle(pool, rnd); + List sampleIdx = pool.subList(0, sampleCount); + + double maxDelta = Double.NEGATIVE_INFINITY; + double[] bestSeg = null; + int bestIdx = -1; + + for (int idx : sampleIdx) { + double[] seq = x[idx]; + picked.add(seq); + double delta = evaluateAim(picked, euclideanCentroids, clusterSize); + picked.remove(picked.size() - 1); + if (delta > maxDelta) { + maxDelta = delta; + bestSeg = Arrays.copyOf(seq, dim); + bestIdx = idx; + } + } + + if (bestSeg == null) { + throw new IllegalStateException("fastKShape: no candidate selected."); + } + picked.add(bestSeg); + coresetIdx.add(bestIdx); + } + + double[][] out = new double[k][dim]; + for (int i = 0; i < k; i++) { + out[i] = picked.get(i); + } + return out; + } + + private static double evaluateAim( + List curCentroids, double[][] euclideanCentroids, long[] clusterSize) { + double res = 0.0; + for (int i = 0; i < euclideanCentroids.length; i++) { + double maxNcc = Double.NEGATIVE_INFINITY; + for (double[] cur : curCentroids) { + double n = ClusterUtils.maxNcc(cur, euclideanCentroids[i]); + if (n > maxNcc) { + maxNcc = n; + } + } + res += maxNcc * clusterSize[i]; + } + return res; + } + + private static void validate(double[][] samples, int k, int maxIterations) { + if (samples == null || samples.length == 0) { + throw new IllegalArgumentException("samples must be non-empty."); + } + if (k < 2) { + throw new IllegalArgumentException("k must be at least 2."); + } + if (k > samples.length) { + throw new IllegalArgumentException("k must not exceed the number of samples."); + } + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be at least 1."); + } + int dim = samples[0].length; + if (dim == 0) { + throw new IllegalArgumentException("sample dimension must be positive."); + } + for (double[] row : samples) { + if (row == null || row.length != dim) { + throw new IllegalArgumentException("All samples must have the same length."); + } + } + } +} diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/UDAFQuantile.java b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/UDAFQuantile.java index 047beafe1fc6a..0518e7a3d768a 100644 --- a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/UDAFQuantile.java +++ b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/UDAFQuantile.java @@ -65,59 +65,65 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati @Override public void transform(Row row, PointCollector collector) throws Exception { - double res = Util.getValueAsDouble(row); - sketch.update(dataToLong(res)); + final long encoded; + switch (dataType) { + case INT32: + encoded = row.getInt(0); + break; + case INT64: + encoded = row.getLong(0); + break; + default: + encoded = dataToLong(Util.getValueAsDouble(row)); + break; + } + sketch.update(encoded); } @Override public void terminate(PointCollector collector) throws Exception { - long result = sketch.findMinValueWithRank((long) (rank * sketch.getN())); - double res = longToResult(result); + long n = sketch.getN(); + // Nearest-rank: k-th smallest uses getApproxRank (strictly-less-than count) in [0, n-1]; + // rank=1 must map to k=n-1, not k=n which is unreachable and can overshoot the max sample. + long k = 0; + if (n > 0) { + k = (long) Math.ceil(rank * n) - 1; + if (k < 0) { + k = 0; + } else if (k >= n) { + k = n - 1; + } + } + long result = sketch.findMinValueWithRank(k); switch (dataType) { case INT32: - collector.putInt(0, (int) res); + collector.putInt(0, (int) result); break; case INT64: - collector.putLong(0, (long) res); + collector.putLong(0, result); break; case FLOAT: - collector.putFloat(0, (float) res); + collector.putFloat(0, (float) longToResult(result)); break; case DOUBLE: - collector.putDouble(0, res); + collector.putDouble(0, longToResult(result)); break; - case TIMESTAMP: - case DATE: - case TEXT: - case STRING: - case BLOB: - case BOOLEAN: default: break; } } - private long dataToLong(Object data) { - long result; + private long dataToLong(double res) { switch (dataType) { - case INT32: - return (int) data; case FLOAT: - result = Float.floatToIntBits((float) data); - return (float) data >= 0f ? result : result ^ Long.MAX_VALUE; - case INT64: - return (long) data; + float f = (float) res; + long flBits = Float.floatToIntBits(f); + return f >= 0f ? flBits : flBits ^ Long.MAX_VALUE; case DOUBLE: - result = Double.doubleToLongBits((double) data); - return (double) data >= 0d ? result : result ^ Long.MAX_VALUE; - case BLOB: - case BOOLEAN: - case STRING: - case TEXT: - case DATE: - case TIMESTAMP: + long d = Double.doubleToLongBits(res); + return res >= 0d ? d : d ^ Long.MAX_VALUE; default: - return (long) data; + return (long) res; } } @@ -129,16 +135,8 @@ private double longToResult(long result) { case DOUBLE: result = (result >>> 63) == 0 ? result : result ^ Long.MAX_VALUE; return Double.longBitsToDouble(result); - case INT64: - case INT32: - case DATE: - case TEXT: - case STRING: - case BOOLEAN: - case BLOB: - case TIMESTAMP: default: - return (result); + return (double) result; } } } diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/ExactOrderStatistics.java b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/ExactOrderStatistics.java index e1f0baa7c0602..47ca5e2b12fd8 100644 --- a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/ExactOrderStatistics.java +++ b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/ExactOrderStatistics.java @@ -31,7 +31,14 @@ import java.io.IOException; import java.util.NoSuchElementException; -/** Util for computing median, MAD, percentile. */ +/** + * Util for computing median, MAD, percentile. + * + *

Percentile / quantile ({@link #getPercentile}) uses discrete nearest-rank: for sorted + * size {@code n} and {@code phi} in (0, 1], take 1-based rank {@code k = ceil(n * phi)} and 0-based + * index {@code k - 1}, clamped to {@code [0, n - 1]}. No interpolation; {@code phi = 0.5} is not + * required to match {@link #getMedian}. + */ public class ExactOrderStatistics { private final Type dataType; @@ -55,12 +62,6 @@ public ExactOrderStatistics(Type type) throws UDFInputSeriesDataTypeNotValidExce case DOUBLE: doubleArrayList = new DoubleArrayList(); break; - case STRING: - case TEXT: - case BOOLEAN: - case BLOB: - case DATE: - case TIMESTAMP: default: // This will not happen. throw new UDFInputSeriesDataTypeNotValidException( @@ -88,12 +89,6 @@ public void insert(Row row) throws UDFInputSeriesDataTypeNotValidException, IOEx doubleArrayList.add(vd); } break; - case DATE: - case TIMESTAMP: - case BLOB: - case BOOLEAN: - case TEXT: - case STRING: default: // This will not happen. throw new UDFInputSeriesDataTypeNotValidException( @@ -111,12 +106,6 @@ public double getMedian() throws UDFInputSeriesDataTypeNotValidException { return getMedian(floatArrayList); case DOUBLE: return getMedian(doubleArrayList); - case TEXT: - case STRING: - case BOOLEAN: - case BLOB: - case TIMESTAMP: - case DATE: default: // This will not happen. throw new UDFInputSeriesDataTypeNotValidException( @@ -199,12 +188,6 @@ public double getMad() throws UDFInputSeriesDataTypeNotValidException { return getMad(floatArrayList); case DOUBLE: return getMad(doubleArrayList); - case TIMESTAMP: - case DATE: - case BLOB: - case BOOLEAN: - case STRING: - case TEXT: default: // This will not happen. throw new UDFInputSeriesDataTypeNotValidException( @@ -251,12 +234,18 @@ public static double getMad(LongArrayList nums) { } } + /** Discrete nearest-rank index into sorted data of length {@code n}; see class Javadoc. */ + private static int discreteNearestRankIndex(int n, double phi) { + int idx = (int) Math.ceil(n * phi) - 1; + return Math.max(0, Math.min(n - 1, idx)); + } + public static float getPercentile(FloatArrayList nums, double phi) { if (nums.isEmpty()) { throw new NoSuchElementException(); } else { nums.sortThis(); - return nums.get((int) Math.ceil(nums.size() * phi)); + return nums.get(discreteNearestRankIndex(nums.size(), phi)); } } @@ -265,7 +254,7 @@ public static double getPercentile(DoubleArrayList nums, double phi) { throw new NoSuchElementException(); } else { nums.sortThis(); - return nums.get((int) Math.ceil(nums.size() * phi)); + return nums.get(discreteNearestRankIndex(nums.size(), phi)); } } @@ -279,12 +268,6 @@ public String getPercentile(double phi) throws UDFInputSeriesDataTypeNotValidExc return Float.toString(getPercentile(floatArrayList, phi)); case DOUBLE: return Double.toString(getPercentile(doubleArrayList, phi)); - case STRING: - case TEXT: - case BOOLEAN: - case BLOB: - case DATE: - case TIMESTAMP: default: // This will not happen. throw new UDFInputSeriesDataTypeNotValidException( @@ -297,7 +280,7 @@ public static int getPercentile(IntArrayList nums, double phi) { throw new NoSuchElementException(); } else { nums.sortThis(); - return nums.get((int) Math.ceil(nums.size() * phi)); + return nums.get(discreteNearestRankIndex(nums.size(), phi)); } } @@ -306,7 +289,7 @@ public static long getPercentile(LongArrayList nums, double phi) { throw new NoSuchElementException(); } else { nums.sortThis(); - return nums.get((int) Math.ceil(nums.size() * phi)); + return nums.get(discreteNearestRankIndex(nums.size(), phi)); } } } diff --git a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/GKArray.java b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/GKArray.java index 1870bdfb7a4c2..7dbcc934e7860 100644 --- a/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/GKArray.java +++ b/library-udf/src/main/java/org/apache/iotdb/library/dprofile/util/GKArray.java @@ -124,6 +124,19 @@ private void compress(List additionalEntries) { i++; + } else if (i >= additionalEntries.size()) { + // Only sketch entries left (must check before comparing additionalEntries.get(i)). + if (j + 1 < entries.size() + && entries.get(j).g + entries.get(j + 1).g + entries.get(j + 1).delta + <= removalThreshold) { + // Removable from sketch. + entries.get(j + 1).g += entries.get(j).g; + } else { + mergedEntries.add(entries.get(j)); + } + + j++; + } else if (additionalEntries.get(i).v < entries.get(j).v) { if (additionalEntries.get(i).g + entries.get(j).g + entries.get(j).delta <= removalThreshold) { @@ -136,7 +149,7 @@ private void compress(List additionalEntries) { i++; - } else { // the same as i == additionalEntries.size() + } else { if (j + 1 < entries.size() && entries.get(j).g + entries.get(j + 1).g + entries.get(j + 1).delta <= removalThreshold) {