Skip to content

Commit

Permalink
Feature/SK-613 | Display model trail in studio (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 11, 2024
1 parent e23421a commit 96a7044
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 0 deletions.
98 changes: 98 additions & 0 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,24 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None

return jsonify(result)

def get_model(self, model_id: str):
result = self.statestore.get_model(model_id)

if result is None:
return (
jsonify({"success": False, "message": "No model found."}),
404,
)

payload = {
"committed_at": result["committed_at"],
"parent_model": result["parent_model"],
"model": result["model"],
"session_id": result["session_id"],
}

return jsonify(payload)

def get_model_trail(self):
"""Get the model trail for a given session.
Expand All @@ -784,6 +802,86 @@ def get_model_trail(self):
{"success": False, "message": "No model trail available."}
)

def get_model_ancestors(self, model_id: str, limit: str = None):
"""Get the model ancestors for a given model.
:param model_id: The model id to get the model ancestors for.
:type model_id: str
:param limit: The number of ancestors to return.
:type limit: str
:return: The model ancestors for the given model as a json response.
:rtype: :class:`flask.Response`
"""
if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10

response = self.statestore.get_model_ancestors(model_id, limit)
if response:

arr: list = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model ancestors available."}
)

def get_model_descendants(self, model_id: str, limit: str = None):
"""Get the model descendants for a given model.
:param model_id: The model id to get the model descendants for.
:type model_id: str
:param limit: The number of descendants to return.
:type limit: str
:return: The model descendants for the given model as a json response.
:rtype: :class:`flask.Response`
"""

if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit: int = int(limit) if limit is not None else 10

response: list = self.statestore.get_model_descendants(model_id, limit)

if response:

arr: list = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model descendants available."}
)

def get_all_rounds(self):
"""Get all rounds.
Expand Down
47 changes: 47 additions & 0 deletions fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,38 @@ def get_model_trail():
return api.get_model_trail()


@app.route("/get_model_ancestors", methods=["GET"])
def get_model_ancestors():
"""Get the ancestors of a model.
param: model: The model id to get the ancestors for.
type: model: str
param: limit: The maximum number of ancestors to return.
type: limit: int
return: A list of model objects that the model derives from.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_ancestors(model, limit)


@app.route("/get_model_descendants", methods=["GET"])
def get_model_descendants():
"""Get the ancestors of a model.
param: model: The model id to get the child for.
type: model: str
param: limit: The maximum number of descendants to return.
type: limit: int
return: A list of model objects that are descendents of the provided model id.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_descendants(model, limit)


@app.route("/list_models", methods=["GET"])
def list_models():
"""Get models from the statestore.
Expand All @@ -50,6 +82,21 @@ def list_models():
return api.get_models(session_id, limit, skip, include_active)


@app.route("/get_model", methods=["GET"])
def get_model():
"""Get a model from the statestore.
param: model: The model id to get.
type: model: str
return: The model as a json object.
rtype: json
"""
model = request.args.get("model", None)
if model is None:
return jsonify({"success": False, "message": "Missing model id."}), 400

return api.get_model(model)


@app.route("/delete_model_trail", methods=["GET", "POST"])
def delete_model_trail():
"""Delete the model trail for a given session.
Expand Down
74 changes: 74 additions & 0 deletions fedn/fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,20 @@ def set_latest_model(self, model_id, session_id=None):
"""

committed_at = datetime.now()
current_model = self.model.find_one({"key": "current_model"})
parent_model = None

# if session_id is set the it means the model is generated from a session
# and we need to set the parent model
# if not the model is uploaded by the user and we don't need to set the parent model
if session_id is not None:
parent_model = current_model["model"] if current_model and "model" in current_model else None

self.model.insert_one(
{
"key": "models",
"model": model_id,
"parent_model": parent_model,
"session_id": session_id,
"committed_at": committed_at,
}
Expand Down Expand Up @@ -534,6 +543,71 @@ def get_model_trail(self):
except (KeyError, IndexError):
return None

def get_model_ancestors(self, model_id: str, limit: int):
"""Get the model ancestors.
:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of ancestors to return.
:type limit: int
:return: List of model ancestors.
:rtype: list
"""
model = self.model.find_one({"key": "models", "model": model_id})
current_model_id = model["parent_model"] if model is not None else None
result = []

for _ in range(limit):
if current_model_id is None:
break

model = self.model.find_one({"key": "models", "model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["parent_model"]

return result

def get_model_descendants(self, model_id: str, limit: int):
"""Get the model descendants.
:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of descendants to return.
:type limit: int
:return: List of model descendants.
:rtype: list
"""

model: object = self.model.find_one({"key": "models", "model": model_id})
current_model_id: str = model["model"] if model is not None else None
result: list = []

for _ in range(limit):
if current_model_id is None:
break

model: str = self.model.find_one({"key": "models", "parent_model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["model"]

result.reverse()

return result

def get_model(self, model_id):
"""Get model with id.
:param model_id: id of model to get
:type model_id: str
:return: model with id
:rtype: ObjectId
"""
return self.model.find_one({"key": "models", "model": model_id})

def get_events(self, **kwargs):
"""Get events from the database.
Expand Down

0 comments on commit 96a7044

Please sign in to comment.