Skip to content

Commit

Permalink
[rust] Avoid panic in error case (#3133)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Apr 26, 2024
1 parent 6efe660 commit ec89a66
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
9 changes: 6 additions & 3 deletions extensions/tokenizers/rust/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod distilbert;
use crate::ndarray::as_data_type;
use crate::{cast_handle, to_handle, to_string_array};
use bert::{BertConfig, BertModel};
use candle_core::DType;
use candle_core::{DType, Error};
use candle_core::{Device, Result, Tensor};
use candle_nn::VarBuilder;
use distilbert::{DistilBertConfig, DistilBertModel};
Expand Down Expand Up @@ -43,7 +43,10 @@ fn load_model<'local>(

// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))?;
let config: Config = serde_json::from_str(&config).unwrap();
let config: Config = match serde_json::from_str(&config) {
Ok(conf) => conf,
Err(err) => return Err(Error::wrap(err)),
};

// Get candle device
let device = if candle_core::utils::cuda_is_available() {
Expand All @@ -55,7 +58,7 @@ fn load_model<'local>(
}?;

// Get candle dtype
let dtype = as_data_type(dtype).unwrap();
let dtype = as_data_type(dtype)?;

let safetensors_path = model_path.join("model.safetensors");
let vb = if safetensors_path.exists() {
Expand Down
36 changes: 22 additions & 14 deletions extensions/tokenizers/rust/src/ndarray/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,22 +295,30 @@ fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) -
match device_type.as_str() {
"cpu" => Ok(Device::Cpu),
"gpu" => {
let mut device = CUDA_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_cuda(0).unwrap();
*device = Some(d.clone());
Ok(d)
if candle_core::utils::cuda_is_available() {
let mut device = CUDA_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_cuda(0).unwrap();
*device = Some(d.clone());
Ok(d)
} else {
Err(Error::Msg(String::from("CUDA is not available.")))
}
}
"mps" => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).unwrap();
*device = Some(d.clone());
Ok(d)
if candle_core::utils::metal_is_available() {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).unwrap();
*device = Some(d.clone());
Ok(d)
} else {
Err(Error::Msg(String::from("metal is not available.")))
}
}
_ => Err(Error::Msg(format!("Invalid device type: {}", device_type))),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
"Model directory doesn't exist: " + modelPath.toAbsolutePath());
}
modelDir = modelPath.toAbsolutePath();
Path config = modelDir.resolve("config.json");
if (!Files.isRegularFile(config)) {
throw new FileNotFoundException("config.json file not found");
}
Path file = modelDir.resolve("model.safetensors");
if (!Files.isRegularFile(file)) {
throw new FileNotFoundException("model.safetensors file not found");
}
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
block = new RsSymbolBlock((RsNDManager) manager, handle);
}
Expand Down

0 comments on commit ec89a66

Please sign in to comment.