-
Notifications
You must be signed in to change notification settings - Fork 0
/
index.js
103 lines (80 loc) · 2.18 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/**
 * Small math helper namespace.
 */
class Maff {
  /**
   * Draws a pseudo-random number from the half-open range [min, max).
   *
   * @param {number} [min=0] - Lower bound (inclusive).
   * @param {number} [max=1] - Upper bound (exclusive).
   * @returns {number} Math.random() scaled to the width of the range and
   *   shifted so that it starts at `min`.
   */
  static random (min = 0, max = 1) {
    const unit = Math.random();
    const span = max - min;
    return unit * span + min;
  }
}
/**
 * Single-layer perceptron without a bias term.
 *
 * It classifies an input vector into one of two classes (1 or -1) by taking
 * the sign of the weighted sum of the inputs.
 */
class Perceptron {
  /** @type {number[]} One weight per input; initialised in the constructor. */
  weights = [];

  /** @type {number} Step size applied to every weight correction. */
  learningRate = 0.00001;

  /**
   * @param {number} [inputCount=2] - Number of inputs (and therefore weights).
   * @param {number} [learningRate=0.00001] - Step size used by `train`.
   */
  constructor (inputCount = 2, learningRate = 0.00001) {
    // Start from random weights in [-1, 1).
    this.weights = Array.from(
      { length: inputCount },
      () => Maff.random(-1, 1),
    );
    this.learningRate = learningRate;
  }

  /**
   * Classifies an input vector.
   *
   * @param {number[]} [inputs=[]] - One value per weight.
   * @returns {number} 1 when the weighted sum is zero or positive,
   *   -1 when it is negative.
   */
  guess (inputs = []) {
    // sigma (sum)
    const sum = inputs.reduce(
      (acc, input, key) => acc + (input * this.weights[key]),
      0,
    );

    // Step function. Math.sign would return 0 for a zero sum, which is not a
    // valid class label, so the boundary case is mapped to 1 explicitly.
    return sum < 0 ? -1 : 1;
  }

  /**
   * Performs one training step on a single labelled sample, updating every
   * weight in place.
   *
   * @param {number[]} [inputs=[]] - One value per weight.
   * @param {number} [target=0] - Expected label for `inputs`.
   * @returns {void}
   */
  train (inputs = [], target = 0) {
    // The error is computed once, from the weights as they were before the
    // update, and then applied to every weight.
    const error = target - this.guess(inputs);

    this.weights = this.weights.map(
      (weight, index) => this.fitWeight(weight, inputs[index], error),
    );
  }

  /**
   * Computes the corrected value of a single weight.
   *
   * @param {number} weight - Current weight.
   * @param {number} input - Input value paired with this weight.
   * @param {number} error - Difference between target and guess.
   * @returns {number} The weight moved by `error * input * learningRate`.
   */
  fitWeight (weight, input, error) {
    return weight + (error * input * this.learningRate);
  }
}
// Labelled sample points; filled in by setup() and iterated by draw().
const points = [];
// The single perceptron instance that draw() trains on the sample points.
const perceptron = new Perceptron();
/**
 * Sketch initialisation: creates the canvas, draws the diagonal line,
 * generates 100 random labelled points into `points`, renders them, and
 * sets the frame rate to 1.
 */
function setup () {
  const [canvasWidth, canvasHeight] = [500, 500];

  createCanvas(canvasWidth, canvasHeight);
  background('#ccc');
  // Diagonal from the top-left to the bottom-right corner of the canvas.
  line(0, 0, canvasWidth, canvasHeight);

  // Build the dataset: a point is labelled 1 (white) when x > y,
  // otherwise -1 (black).
  const sampleCount = 100;
  let created = 0;
  while (created < sampleCount) {
    const x = Maff.random(0, canvasWidth);
    const y = Maff.random(0, canvasHeight);
    const xExceedsY = x > y;

    let label = -1;
    let color = 'black';
    if (xExceedsY) {
      label = 1;
      color = 'white';
    }

    points.push({ x, y, label, color });
    created += 1;
  }

  // Render every point in its label colour.
  points.forEach(({ x, y, color }) => {
    fill(color);
    noStroke();
    circle(x, y, 10);
  });

  frameRate(1);
}
/**
 * Per-frame hook: for every point, trains the perceptron on it, then asks
 * for a fresh guess and overlays a small marker — lime when the guess equals
 * the point's label, red otherwise.
 */
function draw () {
  points.forEach(({ x, y, label }) => {
    const sample = [x, y];

    // Train first so the marker reflects the weights after this update.
    perceptron.train(sample, label);
    const isCorrect = perceptron.guess(sample) === label;

    fill(isCorrect ? 'lime' : 'red');
    noStroke();
    circle(x, y, 5);
  });
}