This commit is contained in:
殷勇 2023-09-08 15:55:22 +08:00
parent 52dbd50f18
commit 18ba5355ae
5 changed files with 50 additions and 15 deletions

View File

@ -51,7 +51,7 @@ refresh_token varchar(255) not null default "",
client_ip varchar(255) not null default "", client_ip varchar(255) not null default "",
user_agent varchar(255) NOT NULL default "", user_agent varchar(255) NOT NULL default "",
is_blocked tinyint not null default "0", is_blocked tinyint not null default "0",
expires_at int not null default "0", expires_at datetime not null,
created_at int not null default "0", created_at int not null default "0",
primary key id(id), primary key id(id),
unique key username(username) unique key username(username)

View File

@ -26,6 +26,7 @@ go get -u github.com/o1egl/paseto
var TokenSymmetricKey = "12345678901234567890123456789012" var TokenSymmetricKey = "12345678901234567890123456789012"
func NewServer(store *sqlx.DB) (*Server, error) { func NewServer(store *sqlx.DB) (*Server, error) {
// NewPasetoMaker or NewJWTMaker, First use NewPasetoMaker
tokenMaker, err := token.NewPasetoMaker(TokenSymmetricKey) tokenMaker, err := token.NewPasetoMaker(TokenSymmetricKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot create token maker: %w", err) return nil, fmt.Errorf("cannot create token maker: %w", err)

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"main/model" "main/model"
"net/http" "net/http"
@ -76,11 +77,23 @@ func (server *Server) loginUser(ctx *gin.Context) {
ctx.JSON(http.StatusInternalServerError, errorResponse(err)) ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return return
} }
if !CheckPassword(req.Password, user.Password) { if !CheckPassword(req.Password, user.Password) {
ctx.JSON(http.StatusUnauthorized, errorResponse(err)) ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return return
} }
// 检查用户是否已经登录,
// 1. 更新会话信息, 2. 终止旧会话, 3. 阻止新登录
hasSession := server.IsUserLoggedIn(req.Username)
if hasSession != nil {
if hasSession.ExpiresAt.After(time.Now()) {
err = fmt.Errorf("session not expired")
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
}
accessTokenDuration, _ := time.ParseDuration("15m") accessTokenDuration, _ := time.ParseDuration("15m")
accessToken, accessPayload, err := server.tokenMaker.CreateToken( accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Username, user.Username,
@ -101,13 +114,14 @@ func (server *Server) loginUser(ctx *gin.Context) {
return return
} }
var createSessionParams model.Session createSessionParams := model.Session{
createSessionParams.Id = refreshPayload.ID Id: refreshPayload.ID,
createSessionParams.Username = req.Username Username: refreshPayload.Username,
createSessionParams.RefreshToken = refreshToken RefreshToken: refreshToken,
createSessionParams.UserAgent = ctx.Request.UserAgent() ClientIp: ctx.ClientIP(),
createSessionParams.ClientIp = ctx.ClientIP() UserAgent: ctx.Request.UserAgent(),
createSessionParams.ExpiresAt = refreshPayload.ExpiredAt ExpiresAt: refreshPayload.ExpiredAt,
}
session, err := model.CreateSession(&createSessionParams) session, err := model.CreateSession(&createSessionParams)
if err != nil { if err != nil {
@ -142,3 +156,11 @@ func (server *Server) listUsers(ctx *gin.Context) {
func (server *Server) getUser(ctx *gin.Context) { func (server *Server) getUser(ctx *gin.Context) {
} }
func (server *Server) IsUserLoggedIn(username string) *model.Session {
session := model.GetSessionByUsername(username)
if session != nil {
return session
}
return nil
}

View File

@ -12,7 +12,7 @@ var db *sqlx.DB
func InitDB() { func InitDB() {
var err error var err error
dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True" dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True&loc=Local"
db, err = sqlx.Connect("mysql", dsn) db, err = sqlx.Connect("mysql", dsn)
if err != nil { if err != nil {
f5.GetSysLog().Info("Failed to connect to the database err:%v \n", err) f5.GetSysLog().Info("Failed to connect to the database err:%v \n", err)

View File

@ -1,7 +1,7 @@
package model package model
import ( import (
"log" "database/sql"
"main/db" "main/db"
"time" "time"
) )
@ -10,22 +10,34 @@ type Session struct {
Id int64 `db:"id"` Id int64 `db:"id"`
Username string `db:"username"` Username string `db:"username"`
RefreshToken string `db:"refresh_token"` RefreshToken string `db:"refresh_token"`
UserAgent string `db:"user_agent"`
ClientIp string `db:"client_ip"` ClientIp string `db:"client_ip"`
UserAgent string `db:"user_agent"`
ExpiresAt time.Time `db:"expires_at"` ExpiresAt time.Time `db:"expires_at"`
} }
func CreateSession(session *Session) (*Session, error) { func CreateSession(session *Session) (*Session, error) {
query := "INSERT INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)" query := "REPLACE INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)"
result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.UserAgent, session.ClientIp, session.ExpiresAt) result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.ClientIp, session.UserAgent, session.ExpiresAt)
if err != nil { if err != nil {
log.Printf("Error creating session: %v", err)
return nil, err return nil, err
} }
_, err = result.LastInsertId() _, err = result.LastInsertId()
if err != nil { if err != nil {
log.Printf("Error getting last insert ID: %v", err)
return nil, err return nil, err
} }
return session, nil return session, nil
} }
func GetSessionByUsername(username string) *Session {
query := "SELECT id, username, refresh_token, client_ip, user_agent, expires_at FROM t_sessions WHERE username = ?"
var session Session
err := db.GetDB().Get(&session, query, username)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
return nil
}
return &session
}