const pool = require('../config/database');
const crypto = require('crypto');

/**
 * Execute a schema (DDL) statement, logging the outcome.
 * Failures are logged and rethrown so the caller can abort initialisation.
 * @param {string} query - SQL statement to run.
 * @param {string} successMessage - Message logged after the statement succeeds.
 * @param {string} failureMessage - Message logged together with the error on failure.
 * @returns {Promise<void>}
 * @throws Whatever `pool.query` rejects with.
 */
async function runSchemaQuery(query, successMessage, failureMessage) {
  try {
    await pool.query(query);
    console.log(successMessage);
  } catch (error) {
    console.error(failureMessage, error);
    throw error;
  }
}

/**
 * Produce a cryptographically random hex string.
 * @param {number} byteLength - Number of random bytes; the result has 2 * byteLength hex chars.
 * @returns {string}
 */
function randomHex(byteLength) {
  return crypto.randomBytes(byteLength).toString('hex');
}

/**
 * Compare two strings without leaking, through timing, how much of them matched.
 * Both inputs are hashed to a fixed-length digest first, so `timingSafeEqual`
 * (which throws on unequal lengths) always receives equal-length buffers and the
 * secret's length is not revealed either.
 * @param {*} a
 * @param {*} b
 * @returns {boolean} true only when both are strings and are equal.
 */
function constantTimeEquals(a, b) {
  if (typeof a !== 'string' || typeof b !== 'string') return false;
  const digestA = crypto.createHash('sha256').update(a).digest();
  const digestB = crypto.createHash('sha256').update(b).digest();
  return crypto.timingSafeEqual(digestA, digestB);
}

/**
 * Data-access layer for OAuth clients and the related authorization-code,
 * access-token and refresh-token tables.
 *
 * NOTE(review): client secrets are stored and compared in plaintext. Consider
 * storing a hash instead (requires a schema/data migration).
 */
class OAuthClient {
  /**
   * Create the `oauth_clients` table if it does not already exist.
   * @returns {Promise<void>}
   */
  static async createTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_clients (
        id SERIAL PRIMARY KEY,
        client_id VARCHAR(100) UNIQUE NOT NULL,
        client_secret VARCHAR(255) NOT NULL,
        name VARCHAR(100) NOT NULL,
        description TEXT,
        redirect_uris TEXT[] NOT NULL,
        scopes TEXT[] DEFAULT ARRAY['read', 'write'],
        user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
        is_active BOOLEAN DEFAULT true,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runSchemaQuery(query, 'OAuth客户端表创建成功', '创建OAuth客户端表失败:');
  }

  /**
   * Create the `oauth_auth_codes` (authorization code) table if it does not exist.
   * @returns {Promise<void>}
   */
  static async createAuthCodeTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_auth_codes (
        id SERIAL PRIMARY KEY,
        code VARCHAR(100) UNIQUE NOT NULL,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        redirect_uri VARCHAR(255) NOT NULL,
        scopes TEXT[],
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runSchemaQuery(query, 'OAuth授权码表创建成功', '创建OAuth授权码表失败:');
  }

  /**
   * Create the `oauth_access_tokens` table if it does not exist.
   * @returns {Promise<void>}
   */
  static async createAccessTokenTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_access_tokens (
        id SERIAL PRIMARY KEY,
        token VARCHAR(255) UNIQUE NOT NULL,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        scopes TEXT[],
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runSchemaQuery(query, 'OAuth访问令牌表创建成功', '创建OAuth访问令牌表失败:');
  }

  /**
   * Create the `oauth_refresh_tokens` table if it does not exist.
   * Rows reference `oauth_access_tokens(id)` with ON DELETE CASCADE, so
   * `createAccessTokenTable` must run first.
   * @returns {Promise<void>}
   */
  static async createRefreshTokenTable() {
    const query = `
      CREATE TABLE IF NOT EXISTS oauth_refresh_tokens (
        id SERIAL PRIMARY KEY,
        token VARCHAR(255) UNIQUE NOT NULL,
        access_token_id INTEGER REFERENCES oauth_access_tokens(id) ON DELETE CASCADE,
        client_id VARCHAR(100) NOT NULL,
        user_id INTEGER NOT NULL,
        is_revoked BOOLEAN DEFAULT FALSE,
        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
      );
    `;
    await runSchemaQuery(query, 'OAuth刷新令牌表创建成功', '创建OAuth刷新令牌表失败:');
  }

  /**
   * Generate a fresh random client id (64 hex chars) and secret (128 hex chars).
   * @returns {{clientId: string, clientSecret: string}}
   */
  static generateClientCredentials() {
    return { clientId: randomHex(32), clientSecret: randomHex(64) };
  }

  /**
   * Register a new OAuth client with freshly generated credentials.
   * @param {object} clientData
   * @param {string} clientData.name
   * @param {string} [clientData.description]
   * @param {string[]} clientData.redirectUris - Allowed redirect URIs (column is NOT NULL).
   * @param {string[]} [clientData.scopes] - Defaults to ['read', 'write'] when falsy.
   * @param {number} [clientData.userId] - Owning user's id.
   * @returns {Promise<object>} The inserted row, including the plaintext `client_secret`.
   * @throws Database errors are logged and rethrown.
   */
  static async create(clientData) {
    const { name, description, redirectUris, scopes, userId } = clientData;
    const { clientId, clientSecret } = this.generateClientCredentials();
    const query = `
      INSERT INTO oauth_clients
        (client_id, client_secret, name, description, redirect_uris, scopes, user_id)
      VALUES ($1, $2, $3, $4, $5, $6, $7)
      RETURNING id, client_id, client_secret, name, description, redirect_uris, scopes, created_at
    `;
    try {
      const result = await pool.query(query, [
        clientId,
        clientSecret,
        name,
        description,
        redirectUris,
        scopes || ['read', 'write'],
        userId,
      ]);
      return result.rows[0];
    } catch (error) {
      console.error('创建OAuth客户端失败:', error);
      throw error;
    }
  }

  /**
   * Replace the secret of an active client with a newly generated one.
   * Also bumps `updated_at` so the change is visible in the row's audit timestamp.
   * @param {string} clientId
   * @returns {Promise<string>} The new plaintext secret.
   * @throws {Error} '客户端不存在或已禁用' when no active client matches; DB errors are rethrown.
   */
  static async resetClientSecret(clientId) {
    const newSecret = randomHex(64);
    const query = `
      UPDATE oauth_clients
      SET client_secret = $1, updated_at = CURRENT_TIMESTAMP
      WHERE client_id = $2 AND is_active = true
    `;
    try {
      const result = await pool.query(query, [newSecret, clientId]);
      if (result.rowCount === 0) {
        throw new Error('客户端不存在或已禁用');
      }
      return newSecret;
    } catch (error) {
      console.error('重置客户端密钥失败:', error);
      throw error;
    }
  }

  /**
   * Look up an active client by its public client id.
   * @param {string} clientId
   * @returns {Promise<object|undefined>} The full row (including `client_secret`), or undefined.
   */
  static async findByClientId(clientId) {
    const query = 'SELECT * FROM oauth_clients WHERE client_id = $1 AND is_active = true';
    try {
      const result = await pool.query(query, [clientId]);
      return result.rows[0];
    } catch (error) {
      console.error('查找OAuth客户端失败:', error);
      throw error;
    }
  }

  /**
   * Check a client id / secret pair against the stored credentials.
   * Uses a constant-time comparison so response timing does not reveal the secret.
   * @param {string} clientId
   * @param {string} clientSecret
   * @returns {Promise<boolean>} false when the client is unknown/inactive or the secret differs.
   */
  static async validateClient(clientId, clientSecret) {
    const client = await this.findByClientId(clientId);
    if (!client) return false;
    return constantTimeEquals(client.client_secret, clientSecret);
  }

  /**
   * Check that `redirectUri` exactly matches one of the client's registered URIs.
   * @param {string} clientId
   * @param {string} redirectUri
   * @returns {Promise<boolean>}
   */
  static async validateRedirectUri(clientId, redirectUri) {
    const client = await this.findByClientId(clientId);
    if (!client || !Array.isArray(client.redirect_uris)) return false;
    return client.redirect_uris.includes(redirectUri);
  }

  /**
   * List a user's active clients. The secret is deliberately not selected.
   * @param {number} userId
   * @returns {Promise<object[]>}
   */
  static async findByUserId(userId) {
    const query = `
      SELECT id, client_id, name, description, redirect_uris, scopes, created_at
      FROM oauth_clients
      WHERE user_id = $1 AND is_active = true
    `;
    try {
      const result = await pool.query(query, [userId]);
      return result.rows;
    } catch (error) {
      console.error('查找用户客户端失败:', error);
      throw error;
    }
  }

  /**
   * Soft-delete a client owned by `userId` by setting `is_active = false`.
   * Note: the filter does not check `is_active`, so deleting an already
   * deactivated client still reports true.
   * @param {string} clientId
   * @param {number} userId
   * @returns {Promise<boolean>} true if a row owned by this user matched.
   */
  static async delete(clientId, userId) {
    const query = `
      UPDATE oauth_clients
      SET is_active = false, updated_at = CURRENT_TIMESTAMP
      WHERE client_id = $1 AND user_id = $2
    `;
    try {
      const result = await pool.query(query, [clientId, userId]);
      return result.rowCount > 0;
    } catch (error) {
      console.error('删除OAuth客户端失败:', error);
      throw error;
    }
  }
}

module.exports = OAuthClient;