211 lines
6.6 KiB
JavaScript
211 lines
6.6 KiB
JavaScript
const pool = require('../config/database');
|
|
const crypto = require('crypto');
|
|
|
|
/**
 * Returns a fresh random secret: 64 bytes from the Node CSPRNG, hex-encoded
 * (128 characters, which fits the VARCHAR(255) `client_secret` column).
 * @returns {string}
 */
function generateSecret() {
  return crypto.randomBytes(64).toString('hex');
}

/**
 * Forwards `queryArgs` unchanged to `pool.query`. If the query rejects, the
 * error is logged under `failureMessage` and rethrown as-is, so callers still
 * receive the original database error.
 * @param {string} failureMessage - Prefix for the console.error line.
 * @param {...*} queryArgs - Arguments passed straight through to `pool.query`.
 * @returns {Promise<*>} Whatever `pool.query` resolves to.
 */
async function runLoggedQuery(failureMessage, ...queryArgs) {
  try {
    return await pool.query(...queryArgs);
  } catch (error) {
    console.error(failureMessage, error);
    throw error;
  }
}

/**
 * SHA-256 digest of a UTF-8 string. Used so secret comparison always works on
 * two equal-length buffers, as `crypto.timingSafeEqual` requires.
 * @param {string} value
 * @returns {Buffer}
 */
function sha256Digest(value) {
  return crypto.createHash('sha256').update(value, 'utf8').digest();
}

/**
 * Static data-access helpers for OAuth clients, plus the DDL for the related
 * OAuth tables (authorization codes, access tokens, refresh tokens).
 * Every query goes through the shared `pool`; failures are logged and rethrown.
 */
class OAuthClient {
  /**
   * Creates the `oauth_clients` table if it does not already exist.
   * `user_id` references `users(id)` (ON DELETE CASCADE), so the users table
   * must exist first.
   * @returns {Promise<void>}
   * @throws Rethrows the database error after logging it.
   */
  static async createTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_clients (
        id SERIAL PRIMARY KEY,
        client_id VARCHAR(100) UNIQUE NOT NULL,
        client_secret VARCHAR(255) NOT NULL,
        name VARCHAR(100) NOT NULL,
        description TEXT,
        redirect_uris TEXT[] NOT NULL,
        scopes TEXT[] DEFAULT ARRAY['read', 'write'],
        user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
        is_active BOOLEAN DEFAULT true,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runLoggedQuery('创建OAuth客户端表失败:', query);
    console.log('OAuth客户端表创建成功');
  }

  /**
   * Creates the `oauth_auth_codes` table (authorization codes) if it does
   * not already exist.
   * @returns {Promise<void>}
   * @throws Rethrows the database error after logging it.
   */
  static async createAuthCodeTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_auth_codes (
        id SERIAL PRIMARY KEY,
        code VARCHAR(100) UNIQUE NOT NULL,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        redirect_uri VARCHAR(255) NOT NULL,
        scopes TEXT[],
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runLoggedQuery('创建OAuth授权码表失败:', query);
    console.log('OAuth授权码表创建成功');
  }

  /**
   * Creates the `oauth_access_tokens` table if it does not already exist.
   * This must run before createRefreshTokenTable, whose
   * `access_token_id` column references this table.
   * @returns {Promise<void>}
   * @throws Rethrows the database error after logging it.
   */
  static async createAccessTokenTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_access_tokens (
        id SERIAL PRIMARY KEY,
        token VARCHAR(255) UNIQUE NOT NULL,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        scopes TEXT[],
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runLoggedQuery('创建OAuth访问令牌表失败:', query);
    console.log('OAuth访问令牌表创建成功');
  }

  /**
   * Creates the `oauth_refresh_tokens` table if it does not already exist.
   * Each row optionally points at an access token and is deleted along with
   * it (ON DELETE CASCADE).
   * @returns {Promise<void>}
   * @throws Rethrows the database error after logging it.
   */
  static async createRefreshTokenTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_refresh_tokens (
        id SERIAL PRIMARY KEY,
        token VARCHAR(255) UNIQUE NOT NULL,
        access_token_id INTEGER REFERENCES oauth_access_tokens(id) ON DELETE CASCADE,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        is_revoked BOOLEAN DEFAULT FALSE,
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runLoggedQuery('创建OAuth刷新令牌表失败:', query);
    console.log('OAuth刷新令牌表创建成功');
  }

  /**
   * Generates a new client ID/secret pair from the CSPRNG.
   * @returns {{clientId: string, clientSecret: string}} A 64-hex-char ID
   *   (32 random bytes) and a 128-hex-char secret (64 random bytes).
   */
  static generateClientCredentials() {
    const clientId = crypto.randomBytes(32).toString('hex');
    const clientSecret = generateSecret();
    return { clientId, clientSecret };
  }

  /**
   * Registers a new OAuth client with freshly generated credentials.
   * @param {object} clientData
   * @param {string} clientData.name - Display name (required by the schema).
   * @param {string} [clientData.description]
   * @param {string[]} clientData.redirectUris - Allowed redirect URIs.
   * @param {string[]} [clientData.scopes] - Defaults to ['read', 'write'] when falsy.
   * @param {number} [clientData.userId] - Owning user's id.
   * @returns {Promise<object>} The inserted row (id, client_id, client_secret,
   *   name, description, redirect_uris, scopes, created_at).
   * @throws Rethrows the database error after logging it.
   */
  static async create(clientData) {
    const { name, description, redirectUris, scopes, userId } = clientData;
    const { clientId, clientSecret } = this.generateClientCredentials();

    const query = `
      INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, user_id)
      VALUES ($1, $2, $3, $4, $5, $6, $7)
      RETURNING id, client_id, client_secret, name, description, redirect_uris, scopes, created_at
    `;
    const result = await runLoggedQuery('创建OAuth客户端失败:', query, [
      clientId,
      clientSecret,
      name,
      description,
      redirectUris,
      scopes || ['read', 'write'],
      userId,
    ]);
    return result.rows[0];
  }

  /**
   * Replaces an active client's secret with a newly generated one and bumps
   * `updated_at`.
   * @param {string} clientId
   * @returns {Promise<string>} The new plaintext secret.
   * @throws {Error} '客户端不存在或已禁用' when no active client matches, or the
   *   database error; both are logged before being thrown.
   */
  static async resetClientSecret(clientId) {
    const newSecret = generateSecret();
    const query = `
      UPDATE oauth_clients
      SET client_secret = $1, updated_at = CURRENT_TIMESTAMP
      WHERE client_id = $2 AND is_active = true
    `;
    const result = await runLoggedQuery('重置客户端密钥失败:', query, [newSecret, clientId]);
    if (result.rowCount === 0) {
      const error = new Error('客户端不存在或已禁用');
      console.error('重置客户端密钥失败:', error);
      throw error;
    }
    return newSecret;
  }

  /**
   * Looks up an active client by its public client_id.
   * @param {string} clientId
   * @returns {Promise<object|undefined>} The full row (including
   *   client_secret), or undefined when no active client matches.
   * @throws Rethrows the database error after logging it.
   */
  static async findByClientId(clientId) {
    const query = 'SELECT * FROM oauth_clients WHERE client_id = $1 AND is_active = true';
    const result = await runLoggedQuery('查找OAuth客户端失败:', query, [clientId]);
    return result.rows[0];
  }

  /**
   * Checks a presented client secret against the stored one.
   * @param {string} clientId
   * @param {string} clientSecret - Secret presented by the caller.
   * @returns {Promise<boolean>} false when the client is unknown/inactive, the
   *   secret is not a string, or the secret does not match.
   */
  static async validateClient(clientId, clientSecret) {
    const client = await this.findByClientId(clientId);
    if (!client || typeof clientSecret !== 'string') return false;
    // Compare fixed-length digests in constant time so response timing does not
    // reveal how much of the secret matched (a plain === short-circuits).
    return crypto.timingSafeEqual(
      sha256Digest(String(client.client_secret)),
      sha256Digest(clientSecret),
    );
  }

  /**
   * Checks whether `redirectUri` is one of the client's registered redirect
   * URIs. This is an exact string match: no prefix, wildcard, or normalization.
   * @param {string} clientId
   * @param {string} redirectUri
   * @returns {Promise<boolean>}
   */
  static async validateRedirectUri(clientId, redirectUri) {
    const client = await this.findByClientId(clientId);
    if (!client || !Array.isArray(client.redirect_uris)) return false;
    return client.redirect_uris.includes(redirectUri);
  }

  /**
   * Lists a user's active clients. client_secret is deliberately excluded
   * from the selected columns.
   * @param {number} userId
   * @returns {Promise<object[]>}
   * @throws Rethrows the database error after logging it.
   */
  static async findByUserId(userId) {
    const query = 'SELECT id, client_id, name, description, redirect_uris, scopes, created_at FROM oauth_clients WHERE user_id = $1 AND is_active = true';
    const result = await runLoggedQuery('查找用户客户端失败:', query, [userId]);
    return result.rows;
  }

  /**
   * Soft-deletes a client owned by `userId`: sets is_active = false and bumps
   * updated_at. The row itself is kept.
   * @param {string} clientId
   * @param {number} userId - Must own the client for any row to change.
   * @returns {Promise<boolean>} true when at least one row was updated.
   * @throws Rethrows the database error after logging it.
   */
  static async delete(clientId, userId) {
    const query = `
      UPDATE oauth_clients
      SET is_active = false, updated_at = CURRENT_TIMESTAMP
      WHERE client_id = $1 AND user_id = $2
    `;
    const result = await runLoggedQuery('删除OAuth客户端失败:', query, [clientId, userId]);
    return result.rowCount > 0;
  }
}
|
|
|
|
module.exports = OAuthClient;