
swiss-ai-center/djl-image-sam-example

DJL Segment Anything Model (SAM)

Introduction

This repository is a Maven project that wraps the Segment Anything Model (SAM) so that it can be used from Java. The process of converting the model is described in the /pytorch_convert/README.md file.

To interface with the TorchScript model, we used the Deep Java Library (DJL). DJL is an engine-agnostic deep learning framework for Java that supports PyTorch, TensorFlow, and MXNet, among other engines, and provides a Java API to load and run TorchScript models.

Project Structure

The project is structured as follows:

Installation & Usage

It is recommended to use an IDE such as IntelliJ IDEA to run the project.

To download the dependencies and build the project, run the following command:

mvn clean install

To run the tests, run the following command:

mvn test

Implement Your Own PyTorch Model

Before implementing a model with DJL, you must first convert your model to TorchScript.

You can also find examples in the DJL documentation here.

Add the Dependencies

Add the following dependencies to your pom.xml file:

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.22.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.22.0</version>
    <scope>runtime</scope>
</dependency>

Note: You can find the latest version of the dependencies here. Use the same version for all DJL artifacts.

Implement the Model

Create a new class for your model. Within this class, load the TorchScript model and run inference. You can find an example here.

The main idea is to create the following objects:

Translator<Image, SamRawOutput> translator;
Criteria<Image, SamRawOutput> criteria;
ZooModel<Image, SamRawOutput> model;
Predictor<Image, SamRawOutput> predictor;

Each object is parameterized with an input type and an output type, and these must match the input and output types of the translator (here, Image and SamRawOutput).
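To make the type matching concrete, here is a minimal plain-Java sketch of how a predictor roughly drives its translator. MiniTranslator and MiniPredictor are hypothetical stand-ins that only mimic the generic shape of DJL's Translator and Predictor; they are not the DJL API. The predict call runs processInput, the model's forward pass (modelled here as a Function from float[] to float[]), and processOutput in that order, and because both classes share the same I and O type parameters, the compiler rejects a translator whose types do not match.

```java
import java.util.function.Function;

// Hypothetical stand-in that mimics the generic shape of DJL's Translator<I, O>.
// This is NOT the DJL API.
interface MiniTranslator<I, O> {
    // Convert the user-facing input into the raw values the model consumes.
    float[] processInput(I input);

    // Convert the raw values produced by the model back into the user-facing output.
    O processOutput(float[] raw);
}

// Hypothetical stand-in that mimics DJL's Predictor<I, O>.
final class MiniPredictor<I, O> {
    private final MiniTranslator<I, O> translator;
    private final Function<float[], float[]> model;

    // The translator must share this predictor's I and O types,
    // so a mismatch is a compile-time error.
    MiniPredictor(MiniTranslator<I, O> translator, Function<float[], float[]> model) {
        this.translator = translator;
        this.model = model;
    }

    // processInput -> forward pass -> processOutput
    O predict(I input) {
        float[] raw = translator.processInput(input);
        float[] result = model.apply(raw);
        return translator.processOutput(result);
    }
}
```

With DJL, the same rule means that the Translator, Criteria, ZooModel, and Predictor objects all carry the same Image/SamRawOutput type pair.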

DJL already implements many input/output types and translators. You can find them here.

Translator

The translator converts the Java input into the tensors expected by the TorchScript model, and converts the model's output tensors back into a Java object. You can find an example here.

It implements the following methods:

  • processInput(TranslatorContext ctx, Image input), which converts the input image into an NDList object.
  • processOutput(TranslatorContext ctx, NDList list), which converts the output NDList object into a SamRawOutput object.
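As an illustration of the kind of work processInput typically does, the sketch below converts an interleaved RGB image (height x width x 3, values 0 to 255) into a normalized channel-first (CHW) float array in plain Java, without any DJL calls. It assumes that the converted TorchScript model expects this layout and the per-channel mean/std used by the original PyTorch SAM code; check the conversion notes in /pytorch_convert/README.md before relying on it. Resizing and padding to the model's input size are left out. ImagePreprocessing and toNormalizedChw are hypothetical names.

```java
// Hypothetical helper, independent of DJL: converts an interleaved RGB image
// (height x width x 3, values 0-255) into a normalized channel-first array.
final class ImagePreprocessing {
    // Per-channel mean/std from the original PyTorch SAM implementation.
    // Assumption: the converted TorchScript model expects the same normalization.
    static final float[] SAM_PIXEL_MEAN = {123.675f, 116.28f, 103.53f};
    static final float[] SAM_PIXEL_STD = {58.395f, 57.12f, 57.375f};

    private static final int CHANNELS = 3;

    private ImagePreprocessing() {}

    static float[] toNormalizedChw(int[] hwcRgb, int height, int width, float[] mean, float[] std) {
        if (hwcRgb.length != height * width * CHANNELS) {
            throw new IllegalArgumentException("Expected " + (height * width * CHANNELS) + " values");
        }
        if (mean.length != CHANNELS || std.length != CHANNELS) {
            throw new IllegalArgumentException("mean and std need one value per channel");
        }
        float[] chw = new float[hwcRgb.length];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < CHANNELS; c++) {
                    // Source index in HWC order, destination index in CHW order.
                    int hwcIndex = (y * width + x) * CHANNELS + c;
                    int chwIndex = c * height * width + y * width + x;
                    chw[chwIndex] = (hwcRgb[hwcIndex] - mean[c]) / std[c];
                }
            }
        }
        return chw;
    }
}
```

Inside a real translator, the resulting array could be wrapped in an NDArray created from ctx.getNDManager() (for example with NDManager.create(float[], Shape)). DJL also provides ready-made transforms such as ToTensor and Normalize in ai.djl.modality.cv.transform that perform similar steps directly on NDArrays.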

SamRawOutput is a custom wrapper class that holds the model's output tensors. You can find an example here.
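A hypothetical output wrapper in the spirit of SamRawOutput is sketched below. Its fields (a row-major array of mask logits and a single IoU score) are assumptions, not the repository's actual class. It stores plain Java arrays, which processOutput could obtain with NDArray.toFloatArray(), and binarizes the logits against a threshold; the original SAM code compares mask logits against 0.0.

```java
// Hypothetical output wrapper in the spirit of SamRawOutput; the field
// names and shapes are assumptions, not the repository's class.
final class MaskOutput {
    private final float[] maskLogits; // row-major, height * width values
    private final int height;
    private final int width;
    private final float iouScore;

    MaskOutput(float[] maskLogits, int height, int width, float iouScore) {
        if (maskLogits.length != height * width) {
            throw new IllegalArgumentException("maskLogits must contain height * width values");
        }
        // Defensive copy so the wrapper does not share state with the caller.
        this.maskLogits = maskLogits.clone();
        this.height = height;
        this.width = width;
        this.iouScore = iouScore;
    }

    float iouScore() {
        return iouScore;
    }

    // A pixel is foreground when its logit is strictly greater than the threshold.
    boolean[][] toBinaryMask(float threshold) {
        boolean[][] mask = new boolean[height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                mask[y][x] = maskLogits[y * width + x] > threshold;
            }
        }
        return mask;
    }
}
```

Copying the data into Java arrays keeps the result independent of the native tensors, which DJL typically releases once a prediction has finished.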

Criteria, Model, and Predictor

The criteria object specifies how the model is loaded: its input and output types, the model path, and the translator to use. You can find an example here.

Note: The model path must point to a directory that contains the .pt file, and the .pt file must have the same name as the directory (for example, sam/sam.pt).
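The directory naming rule can be checked before the path is handed to DJL. The hypothetical helper below (ModelPaths and resolveTorchScriptFile are illustrative names, not part of DJL or this repository) resolves the file named after the directory with a .pt extension and fails early with a clear message if the directory or the file is missing.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: verifies that modelDir contains a TorchScript file
// named after the directory itself, e.g. models/sam/sam.pt.
final class ModelPaths {
    private ModelPaths() {}

    static Path resolveTorchScriptFile(Path modelDir) throws IOException {
        Path dir = modelDir.toAbsolutePath().normalize();
        if (!Files.isDirectory(dir)) {
            throw new IOException("Not a directory: " + dir);
        }
        Path dirName = dir.getFileName();
        if (dirName == null) {
            throw new IOException("Model directory has no name: " + dir);
        }
        // The .pt file must share the directory's name.
        Path ptFile = dir.resolve(dirName.toString() + ".pt");
        if (!Files.isRegularFile(ptFile)) {
            throw new IOException("Expected TorchScript file " + ptFile);
        }
        return ptFile;
    }
}
```

The directory itself, not the .pt file, is what you would then pass to the criteria builder, for example via Criteria.builder().optModelPath(modelDir).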

Calling criteria.loadModel() creates the model object. You can find an example here.

Finally, calling model.newPredictor() creates the predictor object. You can find an example here.

Tips

  • Use an NDManager to create NDArray objects. Inside a translator, you can obtain one from ctx.getNDManager().

Resources