Skip to content

Commit

Permalink
Various fixes and improvements for MultiThreadedBenchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
zemei authored and Tibor Mezei committed Feb 6, 2021
1 parent 6c26228 commit 69911b9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ public static void collectMemoryInfo(Metrics metrics) {
MemoryUsage heap = memBean.getHeapMemoryUsage();
MemoryUsage nonHeap = memBean.getNonHeapMemoryUsage();

long heapCommitted = heap.getCommitted();
long nonHeapCommitted = nonHeap.getCommitted();
long heapUsed = heap.getUsed();
long nonHeapUsed = nonHeap.getUsed();
getProcessInfo(metrics);

metrics.addMetric("Heap", heapCommitted, "bytes");
metrics.addMetric("NonHeap", nonHeapCommitted, "bytes");
metrics.addMetric("Heap", heapUsed, "bytes");
metrics.addMetric("NonHeap", nonHeapUsed, "bytes");
int gpuCount = Device.getGpuCount();

// CudaUtils.getGpuMemory() will allocates memory on GPUs if CUDA runtime is not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,27 @@ public static void main(String[] args) {
@Override
public Object predict(Arguments arguments, Metrics metrics, int iteration)
throws IOException, ModelException {

MemoryTrainingListener.collectMemoryInfo(metrics); // Measure memory before loading model

Object inputData = arguments.getInputData();
ZooModel<?, ?> model = loadModel(arguments, metrics);

int numOfThreads = arguments.getThreads();
int delay = arguments.getDelay();
AtomicInteger counter = new AtomicInteger(iteration);
AtomicInteger counter = new AtomicInteger(iteration + 1);
logger.info("Multithreaded inference with {} threads.", numOfThreads);

List<PredictorCallable> callables = new ArrayList<>(numOfThreads);
for (int i = 0; i < numOfThreads; ++i) {
List<PredictorCallable> callables = new ArrayList<>(numOfThreads + 1);
for (int i = 0; i < numOfThreads + 1; ++i) {
callables.add(new PredictorCallable(model, inputData, metrics, counter, i, i == 0));
}

Object classification = null;
ExecutorService executorService = Executors.newFixedThreadPool(numOfThreads);
ExecutorService executorService = Executors.newFixedThreadPool(numOfThreads + 1);

MemoryTrainingListener.collectMemoryInfo(metrics); // Measure memory before worker kickoff

int successThreads = 0;
try {
metrics.addMetric("mt_start", System.currentTimeMillis(), "mills");
Expand Down Expand Up @@ -93,7 +99,7 @@ public Object predict(Arguments arguments, Metrics metrics, int iteration)
}

model.close();
if (successThreads != numOfThreads) {
if (successThreads != numOfThreads + 1) {
logger.error("Only {}/{} threads finished.", successThreads, numOfThreads);
return null;
}
Expand Down Expand Up @@ -143,21 +149,25 @@ public Object call() throws Exception {
Object result = null;
int count = 0;
int remaining;
while ((remaining = counter.decrementAndGet()) > 0 || result == null) {
try {
result = predictor.predict(inputData);
} catch (Exception e) {
// stop immediately when we find any exception
counter.set(0);
throw e;
}
if (collectMemory) {
if (collectMemory) {
result = "MemoryCollector";
while (counter.get() > 0) {
MemoryTrainingListener.collectMemoryInfo(metrics);
}
int processed = total - remaining + 1;
logger.trace("Worker-{}: {} iteration finished.", workerId, ++count);
if (processed % steps == 0 || processed == total) {
logger.info("Completed {} requests", processed);
} else {
while ((remaining = counter.decrementAndGet()) > 0 || result == null) {
try {
result = predictor.predict(inputData);
} catch (Exception e) {
// stop immediately when we find any exception
counter.set(0);
throw e;
}
int processed = total - remaining + 1;
logger.trace("Worker-{}: {} iteration finished.", workerId, ++count);
if (processed % steps == 0 || processed == total) {
logger.info("Completed {} requests", processed);
}
}
}
logger.debug("Worker-{}: finished.", workerId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,23 @@ public final boolean runBenchmark(String[] args) {
postP50, postP90, postP99));

if (Boolean.getBoolean("collect-memory")) {
float heapBeforeModel = metrics.getMetric("Heap").get(0).getValue().longValue();
float heapBeforeInference = metrics.getMetric("Heap").get(1).getValue().longValue();
float heap = metrics.percentile("Heap", 90).getValue().longValue();
float nonHeap = metrics.percentile("NonHeap", 90).getValue().longValue();
float cpu = metrics.percentile("cpu", 90).getValue().longValue();
float rssBeforeModel = metrics.getMetric("rss").get(0).getValue().longValue();
float rssBeforeInference = metrics.getMetric("rss").get(1).getValue().longValue();
float rss = metrics.percentile("rss", 90).getValue().longValue();

logger.info(String.format("heap P90: %.3f", heap));
logger.info(String.format("nonHeap P90: %.3f", nonHeap));
logger.info(String.format("cpu P90: %.3f", cpu));
logger.info(String.format("rss P90: %.3f", rss));
logger.info(String.format("heap (base): %.3f MB", heapBeforeModel/(1024*1024)));
logger.info(String.format("heap (model): %.3f MB", (heapBeforeInference - heapBeforeModel)/(1024*1024)));
logger.info(String.format("heap (inference) P90: %.3f MB", (heap - heapBeforeInference)/(1024*1024)));
logger.info(String.format("nonHeap P90: %.3f MB", nonHeap/(1024*1024)));
logger.info(String.format("cpu P90: %.3f %%", cpu));
logger.info(String.format("rss (base): %.3f MB", rssBeforeModel/(1024*1024)));
logger.info(String.format("rss (model): %.3f MB", (rssBeforeInference - rssBeforeModel)/(1024*1024)));
logger.info(String.format("rss (inference) P90: %.3f MB", (rss - rssBeforeInference)/(1024*1024)));
}
}
MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());
Expand Down

0 comments on commit 69911b9

Please sign in to comment.