diff --git a/src/controllers/discord.controller.ts b/src/controllers/discord.controller.ts index 8d2fee7..38d6a23 100644 --- a/src/controllers/discord.controller.ts +++ b/src/controllers/discord.controller.ts @@ -4,6 +4,7 @@ import { role, router } from 'decorators/router' import logger from 'logger/logger' import { AuthRecord } from 'modules/AuthRecord' import { DiscordSvr, exchangeDiscrodCodeForToken, userInfo } from 'services/discord.svr' +import { parseOauthState } from 'utils/net.util' class DiscordController extends BaseController { @role(ROLE_ANON) @@ -11,7 +12,7 @@ class DiscordController extends BaseController { async discordCallback(req, res) { let { code, state } = req.params if (code && state) { - const stateArr = state.split('|') + const stateArr = parseOauthState(state) const address = stateArr[0].toLowerCase() const record = await AuthRecord.insertOrUpdate( { address, platform: 7 }, @@ -22,17 +23,26 @@ class DiscordController extends BaseController { record.refreshToken = tokenResponse.refresh_token record.scope = tokenResponse.scope record.tokenType = tokenResponse.token_type - record.expiresIn = tokenResponse.expires_in + Date.now() + record.expiresIn = tokenResponse.expires_in || 0 + Date.now() await record.save() - let uinfo = await userInfo(tokenResponse.access_token) - record.nickname = uinfo.username - record.username = uinfo.username - record.discriminator = uinfo.discriminator - record.openId = uinfo.id - await record.save() - return res.view('/templates/discord_redirect.ejs') - } else { - return res.view('/templates/discord_redirect.ejs') + if (tokenResponse && tokenResponse.access_token) { + let uinfo = await userInfo(tokenResponse.access_token) + record.nickname = uinfo.username + record.username = uinfo.username + record.discriminator = uinfo.discriminator + record.openId = uinfo.id + await record.save() + } + if (stateArr.length > 2) { + return res.redirect(stateArr[2]) + } } + if (state) { + const stateArr = parseOauthState(state) + if (stateArr.length > 2) { + return res.redirect(stateArr[2]) + } + } + return res.view('/templates/discord_redirect.ejs') } } diff --git a/src/controllers/twitter.controller.ts b/src/controllers/twitter.controller.ts index 0c1337f..db58850 100644 --- a/src/controllers/twitter.controller.ts +++ b/src/controllers/twitter.controller.ts @@ -1,9 +1,9 @@ import BaseController, { ROLE_ANON } from 'common/base.controller' -import { ZError } from 'common/ZError' import { role, router } from 'decorators/router' import logger from 'logger/logger' import { AuthRecord } from 'modules/AuthRecord' import { exchangeTwitterCodeForToken, getTwitterUserInfo } from 'services/twitter.svr' +import { parseOauthState } from 'utils/net.util' class TwitterController extends BaseController { @role(ROLE_ANON) @@ -12,7 +12,7 @@ class TwitterController extends BaseController { logger.info('twitter redirect: ', req.params) const { code, state } = req.params if (code && state) { - const stateArr = state.split('|') + const stateArr = parseOauthState(state) const address = stateArr[0].toLowerCase() const record = await AuthRecord.insertOrUpdate( { address, platform: 4 }, @@ -24,13 +24,24 @@ class TwitterController extends BaseController { record.refreshToken = tokenResponse.refresh_token record.scope = tokenResponse.scope record.tokenType = tokenResponse.token_type - record.expiresIn = tokenResponse.expires_in + Date.now() - await record.save() - const uinfo = await getTwitterUserInfo(tokenResponse.access_token) - record.nickname = uinfo.data.name - record.username = uinfo.data.username - record.openId = uinfo.data.id + record.expiresIn = tokenResponse.expires_in || 0 + Date.now() await record.save() + if (tokenResponse && tokenResponse.access_token) { + const uinfo = await getTwitterUserInfo(tokenResponse.access_token) + record.nickname = uinfo.data.name + record.username = uinfo.data.username + record.openId = uinfo.data.id + await record.save() + } + if (stateArr.length > 2) { + return res.redirect(stateArr[2]) + } + } + if (state) { + const stateArr = parseOauthState(state) + if (stateArr.length > 2) { + return res.redirect(stateArr[2]) + } } return res.view('/templates/twitter_redirect.ejs') } diff --git a/src/utils/net.util.ts b/src/utils/net.util.ts index 456606b..2733ab3 100644 --- a/src/utils/net.util.ts +++ b/src/utils/net.util.ts @@ -128,38 +128,38 @@ export function generateKVStr({ sort = false, encode = false, ignoreNull = true, - splitChar = "&", - equalChar = "=", - uri = "", + splitChar = '&', + equalChar = '=', + uri = '', }: { - data?: any; - sort?: boolean; - encode?: boolean; - ignoreNull?: boolean; - splitChar?: string; - equalChar?: string; - uri?: string; + data?: any + sort?: boolean + encode?: boolean + ignoreNull?: boolean + splitChar?: string + equalChar?: string + uri?: string }) { - const keys = Object.keys(data); - sort && keys.sort(); - let result = ""; - let i = 0; + const keys = Object.keys(data) + sort && keys.sort() + let result = '' + let i = 0 for (let key of keys) { if (ignoreNull && !data[key]) { - continue; + continue } - if (i++ > 0) result += splitChar; + if (i++ > 0) result += splitChar if (encode) { - result += `${key}${equalChar}${encodeURIComponent(data[key])}`; + result += `${key}${equalChar}${encodeURIComponent(data[key])}` } else { - result += `${key}${equalChar}${data[key]}`; + result += `${key}${equalChar}${data[key]}` } } if (uri) { - const joinChar = uri.search(/\?/) === -1 ? "?" : "&"; - result = uri + joinChar + result; + const joinChar = uri.search(/\?/) === -1 ? '?' : '&' + result = uri + joinChar + result } - return result; + return result } /** @@ -168,19 +168,23 @@ export function generateKVStr({ * @param splitChar 连接的字符, 默认是& * @param equalChar = */ -export function keyValToObject( - str: string, - splitChar: string = "&", - equalChar = "=" -): {} { - let result: any = {}; +export function keyValToObject(str: string, splitChar: string = '&', equalChar = '='): {} { + let result: any = {} if (!str) { - return result; + return result } - let arrs = str.split(splitChar); + let arrs = str.split(splitChar) for (let sub of arrs) { - let subArr = sub.split(equalChar); - result[subArr[0]] = subArr[1]; + let subArr = sub.split(equalChar) + result[subArr[0]] = subArr[1] + } + return result +} + +export function parseOauthState(state: string) { + if (state.startsWith('0x')) { + return state.split('|') + } else { + return atob(state).split('|') } - return result; } diff --git a/templates/twitter_redirect.ejs b/templates/twitter_redirect.ejs index fdd9afc..ca3146d 100644 --- a/templates/twitter_redirect.ejs +++ b/templates/twitter_redirect.ejs @@ -15,6 +15,8 @@