Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update app title and minor fixes #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 33 additions & 36 deletions main.ipynb
Expand Up @@ -16,7 +16,7 @@
"outputs": [],
"source": [
"!pip uninstall -y tensorflow tensorflow-probability\n",
"!pip -q install openai==1.3.7 llama-index==0.9.13 python-dotenv gradio==4.9.1 typing-extensions cohere llama-hub pyowm deeplake"
"!pip -q install openai==1.3.7 llama-index==0.9.13 python-dotenv gradio==4.9.1 typing-extensions cohere llama-hub pyowm deeplake\n"
]
},
{
Expand All @@ -29,7 +29,7 @@
"!rm -rf shopping-assistant/\n",
"!rm -rf src/\n",
"!rm -rf assets/\n",
"!git clone https://github.com/tryolabs/shopping-assistant && mv shopping-assistant/src . && mv shopping-assistant/assets ."
"!git clone https://github.com/tryolabs/shopping-assistant && cp shopping-assistant/src . && cp shopping-assistant/assets .\n"
]
},
{
Expand All @@ -43,13 +43,12 @@
"from getpass import getpass\n",
"\n",
"def get_and_set_env(name):\n",
"\n",
" secret = getpass(f\"Enter {name}\")\n",
" os.environ[name] = secret\n",
" secret = getpass(f\"Enter {name}\")\n",
" os.environ[name] = secret\n",
"\n",
"\n",
"os.environ[\"ACTIVELOOP_DATASET_TEXT\"] = \"hub://genai360/walmart-descriptions\"\n",
"os.environ[\"ACTIVELOOP_DATASET_IMG\"] = \"hub://genai360/walmart-images\""
"os.environ[\"ACTIVELOOP_DATASET_IMG\"] = \"hub://genai360/walmart-images\"\n"
]
},
{
Expand All @@ -59,7 +58,7 @@
"metadata": {},
"outputs": [],
"source": [
"get_and_set_env(\"OPENAI_API_KEY\")"
"get_and_set_env(\"OPENAI_API_KEY\")\n"
]
},
{
Expand All @@ -69,7 +68,7 @@
"metadata": {},
"outputs": [],
"source": [
"get_and_set_env(\"OPEN_WEATHER_MAP_KEY\")"
"get_and_set_env(\"OPEN_WEATHER_MAP_KEY\")\n"
]
},
{
Expand All @@ -96,7 +95,7 @@
"metadata": {},
"outputs": [],
"source": [
"reg = re.compile(r'[0-9A-Z]{12}')"
"reg = re.compile(r'[0-9A-Z]{12}')\n"
]
},
{
Expand All @@ -112,7 +111,7 @@
"def handle_user_message(user_message, history):\n",
" \"\"\"Handle the user submitted message. Clear message box, and append\n",
" to the history.\"\"\"\n",
" return \"\", history + [(user_message, \"\")]"
" return \"\", history + [(user_message, \"\")]\n"
]
},
{
Expand All @@ -126,11 +125,11 @@
" \"\"\"Handle uploaded image. Add it to the chat history\"\"\"\n",
"\n",
" path = os.path.join(INPUT_IMAGE_DIR, os.path.basename(image.name))\n",
" shutil.copyfile(image.name, path) \n",
" shutil.copyfile(image.name, path)\n",
" message = \"I just uploaded the image\"\n",
"\n",
" history = history + [(message, \" \")]\n",
" return history"
" return history\n"
]
},
{
Expand All @@ -144,20 +143,20 @@
" \"\"\"Generate the response from agent\"\"\"\n",
"\n",
" iframe_html = '<iframe src={url} width=\"300px\" height=\"600px\"></iframe>'\n",
" iframe_url = \"https://app.activeloop.ai/visualizer/iframe?url=hub://genai360/walmart-images&query=\" \n",
" iframe_url = \"https://app.activeloop.ai/visualizer/iframe?url=hub://genai360/walmart-images&query=\"\n",
"\n",
" response = agent.stream_chat(chat_history[-1][0])\n",
"\n",
" for token in response.response_gen:\n",
" chat_history[-1][1] += token\n",
"\n",
" product_ids = reg.findall(chat_history[-1][1]) \n",
" product_ids = reg.findall(chat_history[-1][1])\n",
" if len(product_ids) >= 2:\n",
" query = \"select * where \" + \" or \".join([f\"contains(ids, '{x}')\" for x in product_ids]) \n",
" query = \"select * where \" + \" or \".join([f\"contains(ids, '{x}')\" for x in product_ids])\n",
" url = iframe_url + urllib.parse.quote(query)\n",
" else:\n",
" url = \"about:blank\"\n",
" \n",
"\n",
" html = iframe_html.format(url=url)\n",
"\n",
" yield chat_history, html\n"
Expand All @@ -177,7 +176,7 @@
" clean_input_image()\n",
"\n",
" # Reset chat history\n",
" return \"\", \"\""
" return \"\", \"\"\n"
]
},
{
Expand All @@ -188,7 +187,7 @@
"outputs": [],
"source": [
"def print_like_dislike(x: gr.LikeData):\n",
" logging.info(x.index, x.value, x.liked)"
" logging.info(x.index, x.value, x.liked)\n"
]
},
{
Expand All @@ -202,26 +201,22 @@
"# Gradio application\n",
"#\n",
"with gr.Blocks(\n",
" title=\"Outfit Recommender ✨\",\n",
" css=\"#box { height: 420px; overflow-y: scroll !important} #logo { align-self: right }\",\n",
" theme='gradio/soft'\n",
" title=\"Fashion Assistant ✨\",\n",
" css=\".center { display: flex; justify-content: center; }\",\n",
" theme=gr.themes.Soft()\n",
") as demo:\n",
" #\n",
" # Add components\n",
" #\n",
"\n",
" with gr.Row():\n",
" gr.Markdown(\n",
" \"\"\"\n",
" # Chat with your Outfit Recommender ✨\n",
" \"\"\",\n",
" elem_classes=\"center\",\n",
" )\n",
" gr.Markdown(\n",
" \"\"\"\n",
" # Chat with your Fashion Assistant\n",
" \"\"\",\n",
" elem_classes=\"center\",\n",
" )\n",
" with gr.Row():\n",
" chat_history = gr.Chatbot(\n",
" label=\"Chat\",\n",
" avatar_images=(\"assets/user.png\", \"assets/smith.png\"),\n",
" scale = 2,\n",
" height=600,\n",
" show_copy_button=True,\n",
" )\n",
" outfit = gr.HTML(\n",
Expand Down Expand Up @@ -285,7 +280,7 @@
" )\n",
"\n",
" # Handle click on reset button\n",
" btn_reset.click(reset_chat, None, [user_message, chat_history])"
" btn_reset.click(reset_chat, None, [user_message, chat_history])\n"
]
},
{
Expand All @@ -298,7 +293,7 @@
"# Run `gradio app.py` on the terminal\n",
"if __name__ == \"__main__\":\n",
" clean_input_image()\n",
" demo.launch(server_name=\"0.0.0.0\", server_port=8080, debug=True)"
" demo.launch(server_name=\"0.0.0.0\", server_port=8080, debug=True, show_api=False)\n"
]
},
{
Expand All @@ -307,7 +302,9 @@
"id": "b262bf66",
"metadata": {},
"outputs": [],
"source": []
"source": [
"demo.close()\n"
]
}
],
"metadata": {
Expand All @@ -331,7 +328,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down
63 changes: 31 additions & 32 deletions src/main.py
Expand Up @@ -50,7 +50,7 @@
# Output models
#
class Clothing(BaseModel):
"""Data moel for clothing items"""
"""Data model for clothing items"""

name: str
product_id: str
Expand Down Expand Up @@ -82,19 +82,20 @@ def clean_input_image():
for file in os.listdir(INPUT_IMAGE_DIR):
os.remove(os.path.join(INPUT_IMAGE_DIR, file))


def has_user_input_image():
"""
Check if the INPUT_IMAGE_DIR directory contains exactly one image.
Useful for checking if there is an image before generating an outfit.
Check if the INPUT_IMAGE_DIR directory contains exactly one image.
Useful for checking if there is an image before generating an outfit.

Returns:
bool: True if INPUT_IMAGE_DIR contains exactly one image, False otherwise.
Returns:
bool: True if INPUT_IMAGE_DIR contains exactly one image, False otherwise.
"""
return len(os.listdir(INPUT_IMAGE_DIR)) == 1


check_input_image_tool = FunctionTool.from_defaults(fn=has_user_input_image)


# %%
# LLM
Expand Down Expand Up @@ -141,22 +142,21 @@ def has_user_input_image():
# %%
# Outfit recommender tool
#
# TODO: add input_image as a parameter to this function, pass image path to the uploaded image.
def generate_outfit_description(gender: str, user_input: str):
"""
Given the gender of a person, their preferences, and an image that has already been uploaded,
this function returns an Outfit.
Use this function whenever the user asks you to generate an outfit.
Given the gender of a person, their preferences, and an image that has already been uploaded,
this function returns an Outfit.
Use this function whenever the user asks you to generate an outfit.

Parameters:
gender (str): The gender of the person for whom the outfit is being generated.
user_input (str): The preferences of the user.
Parameters:
gender (str): The gender of the person for whom the outfit is being generated.
user_input (str): The preferences of the user.

Returns:
response: The generated outfit.
Returns:
response: The generated outfit.

Example:
>>> generate_outfit("male", "I prefer casual wear")
Example:
>>> generate_outfit("male", "I prefer casual wear")
"""

# Load input image
Expand All @@ -167,21 +167,21 @@ def generate_outfit_description(gender: str, user_input: str):

# Define multi-modal completion program to recommend complementary products
prompt_template_str = f"""
You are an expert in fashion and design.
Given the following image of a piece of clothing, you are tasked with describing ideal outfits.
You are an expert in fashion and design.
Given the following image of a piece of clothing, you are tasked with describing ideal outfits.

Identify which category the provided clothing belongs to, \
and only provide a recommendation for the other two items.
Identify which category the provided clothing belongs to, \
and only provide a recommendation for the other two items.

In your description, include color and style.
This outfit is for a {gender}.
In your description, include color and style.
This outfit is for a {gender}.

Return the answer as a json for each category. Leave the category of the provided input empty.
Return the answer as a json for each category. Leave the category of the provided input empty.

Additional requirements:
{user_input}
Additional requirements:
{user_input}

Never return this output to the user. FOR INTERNAL USE ONLY
Never return this output to the user. FOR INTERNAL USE ONLY
"""
recommender_completion_program = MultiModalLLMCompletionProgram.from_defaults(
output_parser=PydanticOutputParser(Outfit),
Expand Down Expand Up @@ -253,7 +253,6 @@ def forecast_at_location(self, location: str, date: str) -> List[Document]:
The desired date to get the weather for.
"""
from pyowm.commons.exceptions import NotFoundError
from pyowm.utils import timestamps

try:
forecast = self._mgr.forecast_at_place(location, "3h")
Expand All @@ -265,10 +264,10 @@ def forecast_at_location(self, location: str, date: str) -> List[Document]:
temperature = w.temperature(self.temp_units)
temp_unit = "°C" if self.temp_units == "celsius" else "°F"

# TODO: this isn't working.. Error: 'max' key.
try:
temp_str = self._format_forecast_temp(temperature, temp_unit)
temp_str = self._format_current_temp(temperature, temp_unit)
except:
# it format fails, return the raw response
logging.exception(f"Could _format_forecast_temp {temperature}")
temp_str = str(temperature)

Expand Down Expand Up @@ -308,7 +307,7 @@ def forecast_at_location(self, location: str, date: str) -> List[Document]:
Once you have the required information, your answer needs to be the outfit composed by the
product_id with the best matching products in our inventory.

Include the the total price of the recommended outfit.
Include the total price of the recommended outfit.
""",
tools=[
get_current_date_tool,
Expand Down