This commit is contained in:
殷勇 2023-09-08 15:55:22 +08:00
parent 52dbd50f18
commit 18ba5355ae
5 changed files with 50 additions and 15 deletions

View File

@ -51,7 +51,7 @@ refresh_token varchar(255) not null default "",
client_ip varchar(255) not null default "",
user_agent varchar(255) NOT NULL default "",
is_blocked tinyint not null default "0",
expires_at int not null default "0",
expires_at datetime not null,
created_at int not null default "0",
primary key id(id),
unique key username(username)

View File

@ -26,6 +26,7 @@ go get -u github.com/o1egl/paseto
var TokenSymmetricKey = "12345678901234567890123456789012"
func NewServer(store *sqlx.DB) (*Server, error) {
// NewPasetoMaker or NewJWTMaker, First use NewPasetoMaker
tokenMaker, err := token.NewPasetoMaker(TokenSymmetricKey)
if err != nil {
return nil, fmt.Errorf("cannot create token maker: %w", err)

View File

@ -1,6 +1,7 @@
package api
import (
"fmt"
"github.com/gin-gonic/gin"
"main/model"
"net/http"
@ -76,11 +77,23 @@ func (server *Server) loginUser(ctx *gin.Context) {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
if !CheckPassword(req.Password, user.Password) {
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
// 检查用户是否已经登录,
// 1. 更新会话信息, 2. 终止旧会话, 3. 阻止新登录
hasSession := server.IsUserLoggedIn(req.Username)
if hasSession != nil {
if hasSession.ExpiresAt.After(time.Now()) {
err = fmt.Errorf("session not expired")
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
}
accessTokenDuration, _ := time.ParseDuration("15m")
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Username,
@ -101,13 +114,14 @@ func (server *Server) loginUser(ctx *gin.Context) {
return
}
var createSessionParams model.Session
createSessionParams.Id = refreshPayload.ID
createSessionParams.Username = req.Username
createSessionParams.RefreshToken = refreshToken
createSessionParams.UserAgent = ctx.Request.UserAgent()
createSessionParams.ClientIp = ctx.ClientIP()
createSessionParams.ExpiresAt = refreshPayload.ExpiredAt
createSessionParams := model.Session{
Id: refreshPayload.ID,
Username: refreshPayload.Username,
RefreshToken: refreshToken,
ClientIp: ctx.ClientIP(),
UserAgent: ctx.Request.UserAgent(),
ExpiresAt: refreshPayload.ExpiredAt,
}
session, err := model.CreateSession(&createSessionParams)
if err != nil {
@ -142,3 +156,11 @@ func (server *Server) listUsers(ctx *gin.Context) {
func (server *Server) getUser(ctx *gin.Context) {
}
func (server *Server) IsUserLoggedIn(username string) *model.Session {
session := model.GetSessionByUsername(username)
if session != nil {
return session
}
return nil
}

View File

@ -12,7 +12,7 @@ var db *sqlx.DB
func InitDB() {
var err error
dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True"
dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True&loc=Local"
db, err = sqlx.Connect("mysql", dsn)
if err != nil {
f5.GetSysLog().Info("Failed to connect to the database err:%v \n", err)

View File

@ -1,7 +1,7 @@
package model
import (
"log"
"database/sql"
"main/db"
"time"
)
@ -10,22 +10,34 @@ type Session struct {
Id int64 `db:"id"`
Username string `db:"username"`
RefreshToken string `db:"refresh_token"`
UserAgent string `db:"user_agent"`
ClientIp string `db:"client_ip"`
UserAgent string `db:"user_agent"`
ExpiresAt time.Time `db:"expires_at"`
}
func CreateSession(session *Session) (*Session, error) {
query := "INSERT INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)"
result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.UserAgent, session.ClientIp, session.ExpiresAt)
query := "REPLACE INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)"
result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.ClientIp, session.UserAgent, session.ExpiresAt)
if err != nil {
log.Printf("Error creating session: %v", err)
return nil, err
}
_, err = result.LastInsertId()
if err != nil {
log.Printf("Error getting last insert ID: %v", err)
return nil, err
}
return session, nil
}
func GetSessionByUsername(username string) *Session {
query := "SELECT id, username, refresh_token, client_ip, user_agent, expires_at FROM t_sessions WHERE username = ?"
var session Session
err := db.GetDB().Get(&session, query, username)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
return nil
}
return &session
}