Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: get stuck when load extension in the concurrency environment #184

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 127 additions & 27 deletions dubbo/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use crate::{
};
use dubbo_base::{extension_param::ExtensionType, url::UrlParam, StdError, Url};
use dubbo_logger::tracing::{error, info};
use std::{future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, RwLock};

pub static EXTENSIONS: once_cell::sync::Lazy<ExtensionDirectoryCommander> =
once_cell::sync::Lazy::new(|| ExtensionDirectory::init());
Expand All @@ -41,13 +42,11 @@ impl ExtensionDirectory {
let mut extension_directory = ExtensionDirectory::default();

// register static registry extension
let _ = extension_directory
.register(
StaticRegistry::name(),
StaticRegistry::convert_to_extension_factories(),
ExtensionType::Registry,
)
.await;
let _ = extension_directory.register(
StaticRegistry::name(),
StaticRegistry::convert_to_extension_factories(),
ExtensionType::Registry,
);

while let Some(extension_opt) = rx.recv().await {
match extension_opt {
Expand All @@ -57,20 +56,19 @@ impl ExtensionDirectory {
extension_type,
tx,
) => {
let result = extension_directory
.register(extension_name, extension_factories, extension_type)
.await;
let result = extension_directory.register(
extension_name,
extension_factories,
extension_type,
);
let _ = tx.send(result);
}
ExtensionOpt::Remove(extension_name, extension_type, tx) => {
let result = extension_directory
.remove(extension_name, extension_type)
.await;
let result = extension_directory.remove(extension_name, extension_type);
let _ = tx.send(result);
}
ExtensionOpt::Load(url, extension_type, tx) => {
let result = extension_directory.load(url, extension_type).await;
let _ = tx.send(result);
let _ = extension_directory.load(url, extension_type, tx);
}
}
}
Expand All @@ -79,7 +77,7 @@ impl ExtensionDirectory {
ExtensionDirectoryCommander { sender: tx }
}

async fn register(
fn register(
&mut self,
extension_name: String,
extension_factories: ExtensionFactories,
Expand All @@ -89,47 +87,149 @@ impl ExtensionDirectory {
ExtensionType::Registry => match extension_factories {
ExtensionFactories::RegistryExtensionFactory(registry_extension_factory) => {
self.registry_extension_loader
.register(extension_name, registry_extension_factory)
.await;
.register(extension_name, registry_extension_factory);
Ok(())
}
},
}
}

async fn remove(
fn remove(
&mut self,
extension_name: String,
extension_type: ExtensionType,
) -> Result<(), StdError> {
match extension_type {
ExtensionType::Registry => {
self.registry_extension_loader.remove(extension_name).await;
self.registry_extension_loader.remove(extension_name);
Ok(())
}
}
}

async fn load(
fn load(
&mut self,
url: Url,
extension_type: ExtensionType,
) -> Result<Extensions, StdError> {
callback: oneshot::Sender<Result<Extensions, StdError>>,
) {
match extension_type {
ExtensionType::Registry => {
let extension = self.registry_extension_loader.load(&url).await;
let extension = self.registry_extension_loader.load(url);
match extension {
Ok(extension) => Ok(Extensions::Registry(extension)),
Ok(mut extension) => {
tokio::spawn(async move {
let extension = extension.resolve().await;
match extension {
Ok(extension) => {
let _ = callback.send(Ok(Extensions::Registry(extension)));
}
Err(err) => {
error!("load extension failed: {}", err);
let _ = callback.send(Err(err));
}
}
});
}
Err(err) => {
error!("load extension failed: {}", err);
Err(err)
let _ = callback.send(Err(err));
}
}
}
}
}
}

type ExtensionCreator<T> = Box<
dyn Fn(Url) -> Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>
+ Send
+ Sync
+ 'static,
>;
pub(crate) struct ExtensionPromiseResolver<T> {
resolved_data: Option<T>,
creator: ExtensionCreator<T>,
url: Url,
}

impl<T> ExtensionPromiseResolver<T>
where
T: Send + Clone + 'static,
{
fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
ExtensionPromiseResolver {
resolved_data: None,
creator,
url,
}
}

fn resolved_data(&self) -> Option<T> {
self.resolved_data.clone()
}

async fn resolve(&mut self) -> Result<T, StdError> {
match (self.creator)(self.url.clone()).await {
Ok(data) => {
self.resolved_data = Some(data.clone());
Ok(data)
}
Err(err) => {
error!("create extension failed: {}", err);
Err(LoadExtensionError::new(
"load extension failed, create extension occur an error".to_string(),
)
.into())
}
}
}
}

pub(crate) struct LoadExtensionPromise<T> {
resolver: Arc<RwLock<ExtensionPromiseResolver<T>>>,
}

impl<T> LoadExtensionPromise<T>
where
T: Send + Clone + 'static,
{
pub(crate) fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
let resolver = ExtensionPromiseResolver::new(creator, url);
LoadExtensionPromise {
resolver: Arc::new(RwLock::new(resolver)),
}
}

pub(crate) async fn resolve(&mut self) -> Result<T, StdError> {
onewe marked this conversation as resolved.
Show resolved Hide resolved
// get read lock
let resolver_read_lock = self.resolver.read().await;
// if extension is not None, return it
if let Some(extension) = resolver_read_lock.resolved_data() {
return Ok(extension);
}
drop(resolver_read_lock);

let mut write_lock = self.resolver.write().await;

match write_lock.resolved_data() {
Some(extension) => Ok(extension),
None => {
let extension = write_lock.resolve().await;
extension
}
}
}
}

impl<T> Clone for LoadExtensionPromise<T> {
fn clone(&self) -> Self {
LoadExtensionPromise {
resolver: self.resolver.clone(),
}
}
}

pub struct ExtensionDirectoryCommander {
sender: tokio::sync::mpsc::Sender<ExtensionOpt>,
}
Expand Down Expand Up @@ -280,7 +380,7 @@ pub trait Extension: Sealed {

fn name() -> String;

async fn create(url: &Url) -> Result<Self::Target, StdError>;
async fn create(url: Url) -> Result<Self::Target, StdError>;
}

#[allow(private_bounds)]
Expand Down
79 changes: 41 additions & 38 deletions dubbo/src/extension/registry_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use proxy::RegistryProxy;

use crate::extension::{
ConvertToExtensionFactories, Extension, ExtensionFactories, ExtensionMetaInfo, ExtensionType,
LoadExtensionPromise,
};

// extension://0.0.0.0/?extension-type=registry&extension-name=nacos&registry-url=nacos://127.0.0.1:8848
Expand Down Expand Up @@ -78,27 +79,9 @@ where
T: Extension<Target = Box<dyn Registry + Send + 'static>>,
{
fn convert_to_extension_factories() -> ExtensionFactories {
fn constrain<F>(f: F) -> F
where
F: for<'a> Fn(
&'a Url,
) -> Pin<
Box<
dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>>
+ Send
+ 'a,
>,
>,
{
f
}

let constructor = constrain(|url: &Url| {
let f = <T as Extension>::create(url);
Box::pin(f)
});

ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(constructor))
ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(
<T as Extension>::create,
))
}
}

Expand All @@ -108,19 +91,18 @@ pub(super) struct RegistryExtensionLoader {
}

impl RegistryExtensionLoader {
pub(crate) async fn register(
&mut self,
extension_name: String,
factory: RegistryExtensionFactory,
) {
pub(crate) fn register(&mut self, extension_name: String, factory: RegistryExtensionFactory) {
self.factories.insert(extension_name, factory);
}

pub(crate) async fn remove(&mut self, extension_name: String) {
pub(crate) fn remove(&mut self, extension_name: String) {
self.factories.remove(&extension_name);
}

pub(crate) async fn load(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
pub(crate) fn load(
&mut self,
url: Url,
) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let extension_name = url.query::<ExtensionName>().unwrap();
let extension_name = extension_name.value();
let factory = self.factories.get_mut(&extension_name).ok_or_else(|| {
Expand All @@ -129,19 +111,19 @@ impl RegistryExtensionLoader {
extension_name
))
})?;
factory.create(url).await
factory.create(url)
}
}

type RegistryConstructor = for<'a> fn(
&'a Url,
type RegistryConstructor = fn(
Url,
) -> Pin<
Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send + 'a>,
Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send>,
>;

pub(crate) struct RegistryExtensionFactory {
constructor: RegistryConstructor,
instances: HashMap<String, RegistryProxy>,
instances: HashMap<String, LoadExtensionPromise<RegistryProxy>>,
}

impl RegistryExtensionFactory {
Expand All @@ -154,7 +136,10 @@ impl RegistryExtensionFactory {
}

impl RegistryExtensionFactory {
pub(super) async fn create(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
pub(super) fn create(
&mut self,
url: Url,
) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let registry_url = url.query::<RegistryUrl>().unwrap();
let registry_url = registry_url.value();
let url_str = registry_url.as_str().to_string();
Expand All @@ -164,10 +149,28 @@ impl RegistryExtensionFactory {
Ok(proxy)
}
None => {
let registry = (self.constructor)(url).await?;
let proxy = <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
self.instances.insert(url_str, proxy.clone());
Ok(proxy)
let constructor = self.constructor;

let creator = move |url: Url| {
let registry = constructor(url);
Box::pin(async move {
let registry = registry.await?;
let proxy =
<RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
Ok(proxy)
})
as Pin<
Box<
dyn Future<Output = Result<RegistryProxy, StdError>>
+ Send
+ 'static,
>,
>
};

let promise = LoadExtensionPromise::new(Box::new(creator), url);
self.instances.insert(url_str, promise.clone());
Ok(promise)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion dubbo/src/registry/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ impl Extension for StaticRegistry {
"static".to_string()
}

async fn create(url: &Url) -> Result<Self::Target, StdError> {
async fn create(url: Url) -> Result<Self::Target, StdError> {
// url example:
// extension://0.0.0.0?extension-type=registry&extension-name=static&registry=static://127.0.0.1
let static_invoker_urls = url.query::<StaticInvokerUrls>();
Expand Down
Loading
Loading