forked from flatfeestack/fastauth
/
db.go
175 lines (147 loc) · 5.73 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
package main
import (
"database/sql"
"fmt"
"time"
)
// dbRes holds the columns of one auth row as read by dbSelect.
// database/sql leaves the pointer fields nil (and the []byte fields nil)
// when the corresponding column is NULL.
type dbRes struct {
sms *string // written by updateSMS, set to NULL by resetPassword
password []byte // derived key (dk) written by insertUser / resetPassword
role []byte // 'USR' when created by insertUser
salt []byte // salt stored next to password by insertUser / resetPassword
emailVerified *time.Time // set to CURRENT_TIMESTAMP by updateEmailToken
smsVerified *time.Time // set to CURRENT_TIMESTAMP by updateSMSVerified
totpVerified *time.Time // set to CURRENT_TIMESTAMP by updateTOTPVerified
refreshToken *string // written by insertUser, replaced by updateRefreshToken
totp *string // written by updateTOTP / updateSMS, set to NULL by resetPassword
errorCount *int // incremented by incErrorCount, zeroed by resetCount
}
// dbSelect loads the auth row belonging to email.
//
// The error from Scan is returned as-is (not wrapped), so a missing account
// surfaces as sql.ErrNoRows and callers can compare against it directly.
func dbSelect(email string) (*dbRes, error) {
	const query = "SELECT sms, password, role, salt, emailVerified, refreshToken, totp, smsVerified, totpVerified, errorCount FROM auth WHERE email = ?"

	row := db.QueryRow(query, email)
	out := new(dbRes)
	// Destination order must mirror the column order of the SELECT above.
	if scanErr := row.Scan(
		&out.sms, &out.password, &out.role, &out.salt,
		&out.emailVerified, &out.refreshToken, &out.totp,
		&out.smsVerified, &out.totpVerified, &out.errorCount,
	); scanErr != nil {
		return nil, scanErr
	}
	return out, nil
}
// insertUser creates a new auth row for email with role 'USR'.
//
// dk is stored in the password column next to salt; emailToken and
// refreshToken are stored as given. An error is returned if preparing or
// executing the INSERT fails, or if handleErr sees no affected row.
func insertUser(salt []byte, email string, dk []byte, emailToken string, refreshToken string) error {
	stmt, err := db.Prepare("INSERT INTO auth (email, password, role, salt, emailToken, refreshToken) VALUES (?, ?, 'USR', ?, ?, ?)")
	if err != nil {
		// %w keeps the driver error reachable via errors.Is / errors.As.
		return fmt.Errorf("prepare INSERT INTO auth for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email, dk, salt, emailToken, refreshToken)
	return handleErr(res, err, "INSERT INTO auth", email)
}
// updateRefreshToken replaces oldRefreshToken with newRefreshToken on the
// row that currently holds oldRefreshToken.
//
// The row is identified by the old token only, so no email is available for
// error messages and "n/a" is passed to handleErr instead. If no row holds
// oldRefreshToken, handleErr reports that no rows were affected.
func updateRefreshToken(oldRefreshToken string, newRefreshToken string) error {
	stmt, err := db.Prepare("UPDATE auth SET refreshToken = ? WHERE refreshToken = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE refreshToken statement failed: %w", err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(newRefreshToken, oldRefreshToken)
	return handleErr(res, err, "UPDATE refreshToken", "n/a")
}
// resetPassword stores a new password (derived key dk and its salt) for
// email, but only if forgetEmailToken matches the token stored by
// updateEmailForgotToken.
//
// The same statement also:
//   - clears forgetEmailToken, so a reset token can be used only once;
//   - removes the second factors (totp, sms) together with their
//     verification timestamps. updateSMS only accepts rows whose
//     smsVerified IS NULL, so leaving smsVerified set after clearing sms
//     would block SMS re-enrollment for good.
//
// An empty token is rejected up front so it can never match a row whose
// forgetEmailToken happens to be ''.
//
// TODO: reject tokens older than a validity window; forgetEmail holds the
// time updateEmailForgotToken issued the token.
func resetPassword(salt []byte, email string, dk []byte, forgetEmailToken string) error {
	if forgetEmailToken == "" {
		return fmt.Errorf("UPDATE auth password for %v: empty forgetEmailToken", email)
	}
	stmt, err := db.Prepare("UPDATE auth SET password = ?, salt = ?, totp = NULL, sms = NULL, totpVerified = NULL, smsVerified = NULL, forgetEmailToken = NULL WHERE email = ? AND forgetEmailToken = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth password for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(dk, salt, email, forgetEmailToken)
	return handleErr(res, err, "UPDATE auth password", email)
}
// updateEmailForgotToken stores token as the password-reset token of email
// and records the issue time in forgetEmail.
//
// Storing a new token overwrites any previous one. resetPassword later
// compares against this token.
func updateEmailForgotToken(email string, token string) error {
	// TODO: don't accept too old forget tokens. The age check has to happen
	// where the token is consumed (resetPassword), by comparing forgetEmail
	// against a validity window.
	stmt, err := db.Prepare("UPDATE auth SET forgetEmail = CURRENT_TIMESTAMP, forgetEmailToken = ? WHERE email = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth forgetEmailToken for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(token, email)
	return handleErr(res, err, "UPDATE auth forgetEmailToken", email)
}
// updateTOTP stores totp for email, but only if no totp value is set yet
// (totp IS NULL). If one already exists, no row is affected and handleErr
// reports an error.
func updateTOTP(email string, totp string) error {
	stmt, err := db.Prepare("UPDATE auth SET totp = ? WHERE email = ? and totp IS NULL")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth totp for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(totp, email)
	return handleErr(res, err, "UPDATE auth totp", email)
}
// updateSMS stores totp and sms for email.
//
// It only succeeds while SMS is not yet verified (smsVerified IS NULL). Once
// verified, no row matches and handleErr reports an error, so a verified
// SMS target cannot be replaced through this call.
func updateSMS(email string, totp string, sms string) error {
	stmt, err := db.Prepare("UPDATE auth SET totp = ?, sms = ? WHERE email = ? AND smsVerified IS NULL")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth sms for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(totp, sms, email)
	return handleErr(res, err, "UPDATE auth sms", email)
}
// updateEmailToken marks email as verified when token matches the stored
// emailToken.
//
// It sets emailVerified to CURRENT_TIMESTAMP and clears emailToken, so the
// token cannot be used twice. If the token does not match, no row is
// affected and handleErr reports an error.
func updateEmailToken(email string, token string) error {
	stmt, err := db.Prepare("UPDATE auth SET emailVerified = CURRENT_TIMESTAMP, emailToken = NULL WHERE email = ? AND emailToken = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email, token)
	return handleErr(res, err, "UPDATE auth", email)
}
// updateSMSVerified sets smsVerified to CURRENT_TIMESTAMP for email.
//
// It only applies when an sms value is stored (sms IS NOT NULL); otherwise
// no row is affected and handleErr reports an error.
func updateSMSVerified(email string) error {
	stmt, err := db.Prepare("UPDATE auth SET smsVerified = CURRENT_TIMESTAMP WHERE email = ? AND sms IS NOT NULL")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email)
	return handleErr(res, err, "UPDATE auth SMS timestamp", email)
}
// updateTOTPVerified sets totpVerified to CURRENT_TIMESTAMP for email.
//
// It only applies when a totp value is stored (totp IS NOT NULL); otherwise
// no row is affected and handleErr reports an error.
func updateTOTPVerified(email string) error {
	stmt, err := db.Prepare("UPDATE auth SET totpVerified = CURRENT_TIMESTAMP WHERE email = ? AND totp IS NOT NULL")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email)
	return handleErr(res, err, "UPDATE auth totp timestamp", email)
}
// updateMailStatus records the current time in the emailSent column of
// email's row.
func updateMailStatus(email string) error {
	stmt, err := db.Prepare("UPDATE auth set emailSent = CURRENT_TIMESTAMP WHERE email = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth emailSent for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email)
	return handleErr(res, err, "UPDATE auth emailSent", email)
}
// incErrorCount increments errorCount of email by one.
//
// COALESCE treats a NULL count as 0. In SQL, NULL + 1 is NULL, so a plain
// "errorCount + 1" would never move a count that starts out NULL (dbRes
// scans errorCount into a *int, i.e. reads it as a nullable column). On a
// column that is never NULL, COALESCE is a no-op.
func incErrorCount(email string) error {
	stmt, err := db.Prepare("UPDATE auth set errorCount = COALESCE(errorCount, 0) + 1 WHERE email = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth errorCount for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email)
	return handleErr(res, err, "UPDATE auth errorCount", email)
}
// resetCount sets errorCount of email back to 0.
//
// NOTE(review): drivers that report *changed* rather than *matched* rows
// (e.g. MySQL without clientFoundRows) return 0 rows affected when the count
// is already 0, and handleErr then reports an error. Confirm which driver is
// in use before relying on this returning nil in that case.
func resetCount(email string) error {
	stmt, err := db.Prepare("UPDATE auth set errorCount = 0 WHERE email = ?")
	if err != nil {
		return fmt.Errorf("prepare UPDATE auth errorCount for %v statement failed: %w", email, err)
	}
	defer stmt.Close()
	res, err := stmt.Exec(email)
	return handleErr(res, err, "UPDATE auth errorCount", email)
}
// handleErr turns the outcome of a single Exec into an error.
//
// info names the statement for the message; email identifies the account
// (or is a placeholder such as "n/a"). A non-nil err from Exec is wrapped
// with %w so callers can still inspect the driver error. Otherwise the
// statement must have affected at least one row: a failure to read
// RowsAffected and a count of zero (e.g. a WHERE clause whose token did not
// match) are reported as separate, wrapped errors.
func handleErr(res sql.Result, err error, info string, email string) error {
	if err != nil {
		return fmt.Errorf("%v query %v failed: %w", info, email, err)
	}
	if res == nil {
		return fmt.Errorf("%v query %v returned no result", info, email)
	}
	nr, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("%v %v: reading rows affected failed: %w", info, email, err)
	}
	if nr == 0 {
		return fmt.Errorf("%v %v: no rows affected", info, email)
	}
	return nil
}