Skip to content

Commit

Permalink
change loader and proc params input pattern to align with std map
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Apr 29, 2024
1 parent 6418503 commit 3a86daa
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions src/collective/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpC
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};

std::map<std::string, std::string> json_to_map(xgboost::Json const& config, std::string key) {
auto json_map = xgboost::OptionalArg<xgboost::Object>(config, key, xgboost::JsonObject::Map{});
std::map<std::string, std::string> params{};
for (auto entry : json_map) {
std::string text;
xgboost::Value* value = &(entry.second.GetValue());
if (value->Type() == xgboost::Value::ValueKind::kString) {
text = reinterpret_cast<xgboost::String *>(value)->GetString();
} else if (value->Type() == xgboost::Value::ValueKind::kInteger) {
auto num = reinterpret_cast<xgboost::Integer *>(value)->GetInteger();
text = std::to_string(num);
} else if (value->Type() == xgboost::Value::ValueKind::kNumber) {
auto num = reinterpret_cast<xgboost::Number *>(value)->GetNumber();
text = std::to_string(num);
} else {
text = "Unsupported type ";
}
params[entry.first] = text;
}
return params;
}

void Communicator::Init(Json const& config) {
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
nccl_path_ = nccl;
Expand Down Expand Up @@ -50,30 +72,27 @@ void Communicator::Init(Json const& config) {
std::string proc_params_key{};
std::string proc_params_map{};
plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
loader_params_key = OptionalArg<String>(config, "loader_params_key", loader_params_key);
loader_params_map = OptionalArg<String>(config, "loader_params_map", loader_params_map);
proc_params_key = OptionalArg<String>(config, "proc_params_key", proc_params_key);
proc_params_map = OptionalArg<String>(config, "proc_params_map", proc_params_map);
// Initialize processor if plugin_name is provided
if (!plugin_name.empty()) {
std::map<std::string, std::string> loader_params = {{loader_params_key, loader_params_map}};
std::map<std::string, std::string> proc_params = {{proc_params_key, proc_params_map}};
std::map<std::string, std::string> loader_params = json_to_map(config, "loader_params");
std::map<std::string, std::string> proc_params = json_to_map(config, "proc_params");
processing::ProcessorLoader loader(loader_params);
processor_instance = loader.load(plugin_name);
processor_instance->Initialize(collective::GetRank() == 0, proc_params);
}
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
break;
}
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}
case CommunicatorType::kUnknown:
LOG(FATAL) << "Unknown communicator type.";
break;
}

case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}
case CommunicatorType::kUnknown:
LOG(FATAL) << "Unknown communicator type.";
}
}

Expand Down

0 comments on commit 3a86daa

Please sign in to comment.