-
Notifications
You must be signed in to change notification settings - Fork 0
/
Main.java
44 lines (32 loc) · 1.47 KB
/
Main.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.translate.TranslateException;
import java.io.IOException;
/**
 * Small driver that pushes a dummy batch through a {@code Seq2SeqEncoder} and a
 * {@code Seq2SeqDecoder}, printing the resulting shapes, and then prints the result of
 * {@code NDArray.sequenceMask} on a tiny integer matrix.
 */
public class Main {
    /**
     * Manager shared through a static field; assigned at the start of {@link #main}.
     *
     * <p>The manager is closed when {@code main} finishes, so it must not be used afterwards.
     */
    public static NDManager manager;

    /**
     * Runs the encoder/decoder shape checks and the sequence-mask demo.
     *
     * @param args command-line arguments (unused)
     * @throws IOException declared for callers; not thrown directly by this method's own code
     * @throws TranslateException declared for callers; not thrown directly by this method's own
     *     code
     */
    public static void main(String[] args) throws IOException, TranslateException {
        // NDManager is AutoCloseable; try-with-resources releases the arrays it owns
        // even when one of the forward passes throws.
        try (NDManager baseManager = NDManager.newBaseManager(Functions.tryGpu(0))) {
            // Keep the public static field populated for any code that reads Main.manager.
            manager = baseManager;

            // NOTE(review): constructor arguments are presumably
            // (vocabSize, embedSize, numHiddens, numLayers, dropout) -- confirm against
            // Seq2SeqEncoder / Seq2SeqDecoder.
            Seq2SeqEncoder encoder = new Seq2SeqEncoder(10, 8, 16, 2, 0);
            NDArray input = manager.zeros(new Shape(4, 7));

            NDList encoderResult =
                    encoder.forward(
                            new ParameterStore(manager, false), new NDList(input), false);
            printOutputAndState(encoderResult);

            Seq2SeqDecoder decoder = new Seq2SeqDecoder(10, 8, 16, 2, 0);
            // The decoder derives its starting state from the full encoder result.
            NDList decoderState = decoder.beginState(encoderResult);
            NDList decoderResult =
                    decoder.forward(
                            new ParameterStore(manager, false),
                            new NDList(input).addAll(decoderState),
                            false);
            printOutputAndState(decoderResult);

            NDArray matrix = manager.create(new int[][] {{1, 2, 3}, {4, 5, 6}});
            NDArray validLengths = manager.create(new int[] {1, 2});
            System.out.println(matrix.sequenceMask(validLengths));
        }
    }

    /**
     * Prints the shape of the first array in {@code outputState}, then the number of remaining
     * arrays and the shape of the first of them.
     *
     * @param outputState list whose element 0 is treated as the output and whose elements from
     *     index 1 onward are treated as the state
     * @return the sub-list of {@code outputState} starting at index 1
     */
    private static NDList printOutputAndState(NDList outputState) {
        System.out.println(outputState.get(0).getShape());
        NDList state = outputState.subNDList(1);
        System.out.println(state.size());
        System.out.println(state.get(0).getShape());
        return state;
    }
}