/
hits.py
42 lines (33 loc) · 959 Bytes
/
hits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
from collections import defaultdict
from .base import Metric
from typing import Any
# Module import side effect: switch NumPy's floating-point error handling to
# "raise" for every category (divide, over, under, invalid) instead of the
# default warnings. NumPy error state is process-wide, so this also affects
# NumPy code outside this module.
np.seterr(divide="raise", over="raise", under="raise", invalid="raise")
class Hits(Metric):
    """Hits.

    Number of recommendations made successfully (right predictions),
    tracked separately for each user.

    A recommendation counts as a hit when ``self.relevance_evaluator``
    judges its reward relevant (see ``update_recommendation``).
    """

    def __init__(self, *args, **kwargs):
        """Initialize the metric.

        Args:
            *args: positional arguments forwarded unchanged to the base
                ``Metric`` constructor.
            **kwargs: keyword arguments forwarded unchanged to the base
                ``Metric`` constructor.
        """
        super().__init__(*args, **kwargs)
        # uid -> number of recommendations judged relevant so far.
        self.users_true_positive = defaultdict(int)

    def compute(self, uid: int) -> int:
        """Return the number of hits recorded for a user.

        Args:
            uid (int): user id

        Returns:
            int: how many recommendations for ``uid`` were judged relevant
            in ``update_recommendation``; 0 if none were recorded.
        """
        # Use .get instead of indexing: indexing a defaultdict inserts a
        # 0 entry for every user that is merely queried, silently mutating
        # the metric's state on a read.
        return self.users_true_positive.get(uid, 0)

    def update_recommendation(self, uid: int, item: int, reward: float) -> None:
        """Record the outcome of one recommendation made to a user.

        Increments the user's hit count when the reward is judged relevant;
        otherwise leaves the state unchanged.

        Args:
            uid (int): user id
            item (int): item id (not used by this metric)
            reward (float): reward obtained for the recommendation
        """
        # NOTE(review): relevance_evaluator is not assigned in this class;
        # presumably it is set by Metric.__init__ — confirm in .base.
        if self.relevance_evaluator.is_relevant(reward):
            self.users_true_positive[uid] += 1