From 8a33322bc2d15b74e3c94ad86229c8d303334da7 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 6 Jul 2024 12:07:20 -0700 Subject: [PATCH] [timeseries] Fixes contentLength issue for inference Address part of issue #3271 --- .../java/ai/djl/examples/training/TrainTimeSeries.java | 8 +++++--- .../timeseries/translator/BaseTimeSeriesTranslator.java | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java index b9abb016e5f..66f734e9c7e 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java @@ -120,10 +120,10 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans inputShapes[6] = new Shape( 1, - predictionLength, + contextLength, TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1); - inputShapes[7] = new Shape(1, predictionLength); - inputShapes[8] = new Shape(1, predictionLength); + inputShapes[7] = new Shape(1, contextLength); + inputShapes[8] = new Shape(1, contextLength); trainer.initialize(inputShapes); EasyTrain.fit(trainer, arguments.getEpoch(), trainSet, null); @@ -147,6 +147,7 @@ public static Map predict(String outputDir) Map arguments = new ConcurrentHashMap<>(); arguments.put("prediction_length", predictionLength); + arguments.put("context_length", predictionNetwork.getContextLength()); arguments.put("freq", freq); arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), true); @@ -239,6 +240,7 @@ private static DeepARNetwork getDeepARModel( .setFreq(freq) .setPredictionLength(predictionLength) .optDistrOutput(distributionOutput) + .optContextLength(8) .optUseFeatStaticCat(true); return training ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork(); } diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/translator/BaseTimeSeriesTranslator.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/translator/BaseTimeSeriesTranslator.java index c84e874a352..d34bcfb6fb3 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/translator/BaseTimeSeriesTranslator.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/translator/BaseTimeSeriesTranslator.java @@ -39,8 +39,7 @@ protected BaseTimeSeriesTranslator(BaseBuilder builder) { this.batchifier = builder.batchifier; this.freq = builder.freq; this.predictionLength = builder.predictionLength; - // TODO: for inferring - this.contextLength = builder.predictionLength; + this.contextLength = builder.contextLength; } /** {@inheritDoc} */ @@ -57,6 +56,7 @@ public Batchifier getBatchifier() { public abstract static class BaseBuilder> { protected Batchifier batchifier = Batchifier.STACK; protected int predictionLength; + protected int contextLength; protected String freq; @@ -82,6 +82,8 @@ protected void configPreProcess(Map arguments) { throw new IllegalArgumentException( "The value of `prediction_length` should be > 0"); } + this.contextLength = + ArgumentsUtil.intValue(arguments, "context_length", predictionLength); if (arguments.containsKey("batchifier")) { batchifier = Batchifier.fromString((String) arguments.get("batchifier")); }