Skip to content

Commit

Permalink
Merge pull request #16 from fourlastor-alexandria/context-class-loader
Browse files Browse the repository at this point in the history
Set context class loader from main class
  • Loading branch information
fourlastor committed Jan 26, 2024
2 parents c63be37 + 9aa5eb3 commit fab2f18
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ struct Config {
mainClass: String,
vmArgs: Vec<String>,
useZgcIfSupportedOs: bool,
useMainAsContextClassLoader: bool,
}

// Picks discrete GPU on Windows, if possible
Expand Down Expand Up @@ -41,9 +42,10 @@ const JVM_LOCATION: [&str; 3] = ["jdk", "lib", "server"];
fn start_jvm(
jvm_location: &Path,
class_path: Vec<String>,
main_class: &str,
main_class_name: &str,
vm_args: Vec<String>,
use_zgc_if_supported: bool,
use_main_as_context_class_loader: bool,
args: Vec<String>,
) {
let mut args_builder = InitArgsBuilder::new()
Expand Down Expand Up @@ -75,6 +77,44 @@ fn start_jvm(
.attach_current_thread()
.expect("Failed to attach the current thread");

if use_main_as_context_class_loader {
// Class mainClass = MainClass.class;
let main_class = env
.find_class(main_class_name)
.expect("Failed to get main class");

// ClassLoader loader = mainClass.getClassLoader()
let class_loader = env
.call_method(
main_class,
"getClassLoader",
"()Ljava/lang/ClassLoader;",
&[],
)
.and_then(|it| it.l())
.expect("Failed to get class loader from main class");

// Thread thread = Thread.currentThread()
let current_thread = env
.call_static_method(
"java/lang/Thread",
"currentThread",
"()Ljava/lang/Thread;",
&[],
)
.and_then(|it| it.l())
.expect("Failed to get current thread");

// thread.setContextClassLoader(loader)
env.call_method(
current_thread,
"setContextClassLoader",
"(Ljava/lang/ClassLoader;)V",
&[(&class_loader).into()],
)
.expect("Failed to set class loader");
}

let jstrings: Vec<JString> = args
.iter()
.map(|s| env.new_string(s)) // Convert to JString (maybe)
Expand All @@ -92,7 +132,7 @@ fn start_jvm(
i = i + 1;
}
env.call_static_method(
main_class,
main_class_name,
"main",
"([Ljava/lang/String;)V",
&[(&method_args).into()],
Expand Down Expand Up @@ -168,13 +208,24 @@ fn main() {
let config_file_path = current_location.join("config.json");
let data = fs::read_to_string(config_file_path).expect("Unable to read config file");
let config: Config = serde_json::from_str(&data).expect("Invalid config json");
let class_path: Vec<String> = config.classPath.into_iter().map(|it| current_location.join(it).into_os_string().into_string().unwrap()).collect();
let class_path: Vec<String> = config
.classPath
.into_iter()
.map(|it| {
current_location
.join(it)
.into_os_string()
.into_string()
.unwrap()
})
.collect();
start_jvm(
&jvm_location,
class_path,
&config.mainClass.replace(".", "/"),
config.vmArgs,
config.useZgcIfSupportedOs,
config.useMainAsContextClassLoader,
args,
);
}

0 comments on commit fab2f18

Please sign in to comment.