save
This commit is contained in:
parent
52dbd50f18
commit
18ba5355ae
@ -51,7 +51,7 @@ refresh_token varchar(255) not null default "",
|
|||||||
client_ip varchar(255) not null default "",
|
client_ip varchar(255) not null default "",
|
||||||
user_agent varchar(255) NOT NULL default "",
|
user_agent varchar(255) NOT NULL default "",
|
||||||
is_blocked tinyint not null default "0",
|
is_blocked tinyint not null default "0",
|
||||||
expires_at int not null default "0",
|
expires_at datetime not null,
|
||||||
created_at int not null default "0",
|
created_at int not null default "0",
|
||||||
primary key id(id),
|
primary key id(id),
|
||||||
unique key username(username)
|
unique key username(username)
|
||||||
|
@ -26,6 +26,7 @@ go get -u github.com/o1egl/paseto
|
|||||||
var TokenSymmetricKey = "12345678901234567890123456789012"
|
var TokenSymmetricKey = "12345678901234567890123456789012"
|
||||||
|
|
||||||
func NewServer(store *sqlx.DB) (*Server, error) {
|
func NewServer(store *sqlx.DB) (*Server, error) {
|
||||||
|
// NewPasetoMaker or NewJWTMaker, First use NewPasetoMaker
|
||||||
tokenMaker, err := token.NewPasetoMaker(TokenSymmetricKey)
|
tokenMaker, err := token.NewPasetoMaker(TokenSymmetricKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot create token maker: %w", err)
|
return nil, fmt.Errorf("cannot create token maker: %w", err)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"main/model"
|
"main/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -76,11 +77,23 @@ func (server *Server) loginUser(ctx *gin.Context) {
|
|||||||
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
|
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !CheckPassword(req.Password, user.Password) {
|
if !CheckPassword(req.Password, user.Password) {
|
||||||
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
|
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查用户是否已经登录,
|
||||||
|
// 1. 更新会话信息, 2. 终止旧会话, 3. 阻止新登录
|
||||||
|
hasSession := server.IsUserLoggedIn(req.Username)
|
||||||
|
if hasSession != nil {
|
||||||
|
if hasSession.ExpiresAt.After(time.Now()) {
|
||||||
|
err = fmt.Errorf("session not expired")
|
||||||
|
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
accessTokenDuration, _ := time.ParseDuration("15m")
|
accessTokenDuration, _ := time.ParseDuration("15m")
|
||||||
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
|
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
|
||||||
user.Username,
|
user.Username,
|
||||||
@ -101,13 +114,14 @@ func (server *Server) loginUser(ctx *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var createSessionParams model.Session
|
createSessionParams := model.Session{
|
||||||
createSessionParams.Id = refreshPayload.ID
|
Id: refreshPayload.ID,
|
||||||
createSessionParams.Username = req.Username
|
Username: refreshPayload.Username,
|
||||||
createSessionParams.RefreshToken = refreshToken
|
RefreshToken: refreshToken,
|
||||||
createSessionParams.UserAgent = ctx.Request.UserAgent()
|
ClientIp: ctx.ClientIP(),
|
||||||
createSessionParams.ClientIp = ctx.ClientIP()
|
UserAgent: ctx.Request.UserAgent(),
|
||||||
createSessionParams.ExpiresAt = refreshPayload.ExpiredAt
|
ExpiresAt: refreshPayload.ExpiredAt,
|
||||||
|
}
|
||||||
|
|
||||||
session, err := model.CreateSession(&createSessionParams)
|
session, err := model.CreateSession(&createSessionParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -142,3 +156,11 @@ func (server *Server) listUsers(ctx *gin.Context) {
|
|||||||
func (server *Server) getUser(ctx *gin.Context) {
|
func (server *Server) getUser(ctx *gin.Context) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (server *Server) IsUserLoggedIn(username string) *model.Session {
|
||||||
|
session := model.GetSessionByUsername(username)
|
||||||
|
if session != nil {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -12,7 +12,7 @@ var db *sqlx.DB
|
|||||||
|
|
||||||
func InitDB() {
|
func InitDB() {
|
||||||
var err error
|
var err error
|
||||||
dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True"
|
dsn := "root:keji178@tcp(login-test.kingsome.cn:3306)/admindb_dev?charset=utf8mb4&parseTime=True&loc=Local"
|
||||||
db, err = sqlx.Connect("mysql", dsn)
|
db, err = sqlx.Connect("mysql", dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f5.GetSysLog().Info("Failed to connect to the database err:%v \n", err)
|
f5.GetSysLog().Info("Failed to connect to the database err:%v \n", err)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"database/sql"
|
||||||
"main/db"
|
"main/db"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -10,22 +10,34 @@ type Session struct {
|
|||||||
Id int64 `db:"id"`
|
Id int64 `db:"id"`
|
||||||
Username string `db:"username"`
|
Username string `db:"username"`
|
||||||
RefreshToken string `db:"refresh_token"`
|
RefreshToken string `db:"refresh_token"`
|
||||||
UserAgent string `db:"user_agent"`
|
|
||||||
ClientIp string `db:"client_ip"`
|
ClientIp string `db:"client_ip"`
|
||||||
|
UserAgent string `db:"user_agent"`
|
||||||
ExpiresAt time.Time `db:"expires_at"`
|
ExpiresAt time.Time `db:"expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateSession(session *Session) (*Session, error) {
|
func CreateSession(session *Session) (*Session, error) {
|
||||||
query := "INSERT INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)"
|
query := "REPLACE INTO t_sessions (id, username, refresh_token, client_ip, user_agent, expires_at) VALUES (?, ?, ?, ?, ?, ?)"
|
||||||
result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.UserAgent, session.ClientIp, session.ExpiresAt)
|
result, err := db.GetDB().Exec(query, session.Id, session.Username, session.RefreshToken, session.ClientIp, session.UserAgent, session.ExpiresAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error creating session: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err = result.LastInsertId()
|
_, err = result.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error getting last insert ID: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetSessionByUsername(username string) *Session {
|
||||||
|
query := "SELECT id, username, refresh_token, client_ip, user_agent, expires_at FROM t_sessions WHERE username = ?"
|
||||||
|
var session Session
|
||||||
|
err := db.GetDB().Get(&session, query, username)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user