refactor into multipurpose

parent d650a7496f · commit 6578cca96d

8 changed files with 443 additions and 374 deletions
Cargo.lock — 7 lines changed (generated)

@@ -86,6 +86,12 @@ dependencies = [
  "windows-sys 0.59.0",
 ]
 
+[[package]]
+name = "anyhow"
+version = "1.0.97"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
+
 [[package]]
 name = "async-trait"
 version = "0.1.86"

@@ -1161,6 +1167,7 @@ dependencies = [
 name = "noah"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "axum",
  "clap",
  "pgvector",
|
11
Cargo.toml
11
Cargo.toml
|
@ -4,9 +4,10 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["postgres", "search"]
|
||||
postgres = ["sqlx/postgres", "pgvector"]
|
||||
search = ["postgres"]
|
||||
default = []
|
||||
relay = []
|
||||
discovery = ["pgvector"]
|
||||
search = ["discovery"]
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.3", features = ["derive"] }
|
||||
|
@ -15,10 +16,12 @@ tokio = { version = "1.0", features = ["full"] }
|
|||
sqlx = { version = "0.7", features = [
|
||||
"runtime-tokio-native-tls",
|
||||
"json",
|
||||
], optional = true }
|
||||
"postgres",
|
||||
] }
|
||||
pgvector = { version = "0.3", features = ["sqlx"], optional = true }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
anyhow = "1.0.97"
|
||||
|
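The reshaped feature graph is the heart of the refactor: `postgres` disappears as a user-facing switch (sqlx's `postgres` feature is now always on), and `search = ["discovery"]` makes search an extension of discovery, so building with `--features search` transitively enables `discovery` and, through it, the optional `pgvector` dependency. A minimal, self-contained sketch of how such flags gate code at compile time (the module body and messages here are illustrative, not from this crate):

```rust
// Build with: cargo run --features search   (also turns on `discovery`)
#[cfg(feature = "discovery")]
mod discovery {
    pub fn status() -> &'static str {
        "discovery compiled in"
    }
}

fn main() {
    #[cfg(feature = "discovery")]
    println!("{}", discovery::status());

    // cfg! evaluates to a compile-time true/false for the active feature set
    if cfg!(feature = "relay") {
        println!("relay compiled in");
    }
    if cfg!(feature = "search") {
        println!("search compiled in (implies discovery)");
    }
}
```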
|
17
src/db.rs
Normal file
17
src/db.rs
Normal file
|
@ -0,0 +1,17 @@
|
|||
use sqlx::{postgres::PgPoolOptions, PgPool, migrate::Migrator};
|
||||
use anyhow::Result;
|
||||
|
||||
//static MIGRATOR: Migrator = sqlx::migrate!();
|
||||
|
||||
pub type Db = PgPool;
|
||||
|
||||
pub async fn initDb(connection_string: &str) -> Result<Db> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&connection_string)
|
||||
.await?;
|
||||
|
||||
//MIGRATOR.run(&pool).await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
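main.rs only calls this through a commented-out line so far, but the intended wiring is visible there. A hedged sketch of a caller, assuming the same `POSTGRES_CONNECTION` environment variable that main.rs reads and that the module is reachable as `db` in the binary:

```rust
use anyhow::Result;

mod db; // as declared in main.rs (`pub mod db;`)

#[tokio::main]
async fn main() -> Result<()> {
    // Mirrors the commented-out line in main.rs:
    // `db::initDb(&std::env::var("POSTGRES_CONNECTION")?).await?`
    let pool = db::initDb(&std::env::var("POSTGRES_CONNECTION")?).await?;

    // Smoke-test the pool with a trivial query.
    let one: (i32,) = sqlx::query_as("SELECT 1").fetch_one(&pool).await?;
    assert_eq!(one.0, 1);
    Ok(())
}
```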
src/discovery/embedding.rs — 54 lines changed (new file)

@@ -0,0 +1,54 @@
+use crate::App;
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+
+pub type Embedding = Vec<f32>;
+
+pub mod cloudflare {
+    #[derive(Serialize)]
+    struct EmbeddingContext<'a> {
+        text: &'a str,
+    }
+
+    #[derive(Serialize)]
+    struct EmbeddingRequest<'a> {
+        query: Option<&'a str>, // honestly we are probably never going to use this field
+        contexts: Vec<EmbeddingContext<'a>>,
+    }
+
+    #[derive(Deserialize)]
+    struct EmbeddingResult {
+        response: Vec<Embedding>,
+    }
+
+    #[derive(Deserialize)]
+    struct EmbeddingResponse {
+        result: EmbeddingResult,
+        success: bool,
+        errors: Vec<String>,
+        messages: Vec<String>,
+    }
+
+    pub async fn generate_embedding(app: &App, text: &str) -> Result<Embedding> {
+        let response = app
+            .web_client
+            .post(format!(
+                "https://api.cloudflare.com/client/v4/accounts/{}/ai/run/@cf/baai/bge-m3",
+                app.cloudflare_account_id
+            ))
+            .header(
+                "Authorization",
+                format!("Bearer {}", app.cloudflare_api_key),
+            )
+            .json(&EmbeddingRequest {
+                query: None,
+                contexts: vec![EmbeddingContext { text }],
+            })
+            .send()
+            .await?
+            .json::<EmbeddingResponse>()
+            .await?;
+
+        Ok(response.result.response[0].clone())
+    }
+}
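Two wrinkles in the committed module are worth flagging: a child module does not inherit the file-level `use` items (so `Serialize`, `Result`, `App`, and `Embedding` are unresolved inside `cloudflare` as written), and the credentials on `App` are `Option<String>`, which cannot be formatted into the URL directly. A standalone, hedged sketch of the same call with those two issues handled — the endpoint and request/response shapes are copied from the diff; dropping the `errors`/`messages` fields is fine since serde ignores unknown JSON fields; failing via `anyhow!` is an assumption, not necessarily the author's plan:

```rust
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};

type Embedding = Vec<f32>;

#[derive(Serialize)]
struct EmbeddingContext<'a> {
    text: &'a str,
}

#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    query: Option<&'a str>,
    contexts: Vec<EmbeddingContext<'a>>,
}

#[derive(Deserialize)]
struct EmbeddingResult {
    response: Vec<Embedding>,
}

#[derive(Deserialize)]
struct EmbeddingResponse {
    result: EmbeddingResult,
    success: bool,
}

#[tokio::main]
async fn main() -> Result<()> {
    // main.rs reads these from the environment into Option<String>;
    // unwrap them explicitly before building the request.
    let account_id = std::env::var("CLOUDFLARE_ACCOUNT_ID")
        .map_err(|_| anyhow!("CLOUDFLARE_ACCOUNT_ID not set"))?;
    let api_key = std::env::var("CLOUDFLARE_API_KEY")
        .map_err(|_| anyhow!("CLOUDFLARE_API_KEY not set"))?;

    let resp = reqwest::Client::new()
        .post(format!(
            "https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/baai/bge-m3"
        ))
        .header("Authorization", format!("Bearer {api_key}"))
        .json(&EmbeddingRequest {
            query: None,
            contexts: vec![EmbeddingContext { text: "hello" }],
        })
        .send()
        .await?
        .json::<EmbeddingResponse>()
        .await?;

    println!("success={} dims={}", resp.success, resp.result.response[0].len());
    Ok(())
}
```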
src/discovery/mod.rs — 3 lines changed (new file)

@@ -0,0 +1,3 @@
+pub(super) mod embedding;
+#[cfg(feature = "search")]
+pub mod search;
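Because `discovery` hangs off the crate root, `pub(super)` here is effectively crate-visible: the parent module (the root) and everything under it can name `embedding`, while external crates cannot. A tiny self-contained illustration of the same shape (inline modules and function names are hypothetical):

```rust
mod discovery {
    pub(super) mod embedding {
        // visible to the parent module (the crate root) and its descendants
        pub fn generate() -> &'static str {
            "embedding"
        }
    }

    pub mod search {
        pub fn run() -> &'static str {
            // a sibling module reaches it through `super`
            super::embedding::generate()
        }
    }
}

fn main() {
    // the root (and anything below it) can reach pub(super) items too
    println!("{}", discovery::search::run());
    println!("{}", discovery::embedding::generate());
}
```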
src/discovery/search.rs — 241 lines changed (new file)

@@ -0,0 +1,241 @@
+use crate::{App, discovery::embedding::cloudflare::generate_embedding};
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+use sqlx::{QueryBuilder, Row};
+use pgvector::Vector;
+
+#[derive(Debug, Deserialize, Clone)]
+struct TagFilters {
+    exact: Option<serde_json::Value>,
+    any: Option<Vec<String>>,
+    values: Option<TagValueFilter>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+struct TagValueFilter {
+    key: String,
+    values: Vec<String>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+pub struct SearchQuery {
+    query: String,
+    limit: Option<i64>,
+    offset: Option<i64>,
+    filters: Option<SearchFilters>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+struct SearchFilters {
+    pubkey: Option<String>,
+    kind: Option<i32>,
+    tags: Option<TagFilters>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct SearchResult {
+    results: Vec<EventWithSimilarity>,
+    total: i64,
+    limit: i64,
+    offset: i64,
+}
+
+#[derive(Debug, Serialize, sqlx::FromRow)]
+struct EventWithSimilarity {
+    id: String,
+    pubkey: String,
+    created_at: i64,
+    kind: i32,
+    content: String,
+    tags: serde_json::Value,
+    similarity: f64,
+}
+
+pub async fn search_events(app: &App, search_query: SearchQuery) -> Result<SearchResult> {
+    let limit = search_query.limit.unwrap_or(10);
+    let offset = search_query.offset.unwrap_or(0);
+
+    let embedding_vec = generate_embedding(&app, &search_query.query).await?;
+
+    let embedding = Vector::from(embedding_vec);
+
+    // Start building the base query
+    let mut qb = QueryBuilder::new(
+        "SELECT id, pubkey, created_at, kind, content, tags, 1 - (embedding <=> ",
+    );
+
+    qb.push_bind(embedding.clone())
+        .push(") as similarity FROM nostr_search.events");
+
+    // Add WHERE clause if we have filters
+    if let Some(filters) = search_query.clone().filters {
+        let mut first_condition = true;
+
+        if let Some(pubkey) = filters.pubkey {
+            qb.push(" WHERE pubkey = ");
+            qb.push_bind(pubkey);
+            first_condition = false;
+        }
+
+        if let Some(kind) = filters.kind {
+            if first_condition {
+                qb.push(" WHERE ");
+            } else {
+                qb.push(" AND ");
+            }
+            qb.push("kind = ");
+            qb.push_bind(kind);
+            first_condition = false;
+        }
+
+        if let Some(tag_filters) = filters.tags {
+            if let Some(exact) = tag_filters.exact {
+                if first_condition {
+                    qb.push(" WHERE ");
+                } else {
+                    qb.push(" AND ");
+                }
+                qb.push("tags @> ");
+                qb.push_bind(exact);
+            }
+        }
+    }
+
+    // Add ordering, limit and offset
+    qb.push(" ORDER BY embedding <=> ")
+        .push_bind(embedding)
+        .push(" LIMIT ")
+        .push_bind(limit)
+        .push(" OFFSET ")
+        .push_bind(offset);
+
+    // Build and execute the query
+    let query = qb.build_query_as::<EventWithSimilarity>();
+
+    let results = query.fetch_all(&app.db).await?;
+
+    // Build the count query
+    let mut count_qb = QueryBuilder::new("SELECT COUNT(*) FROM nostr_search.events");
+
+    if let Some(filters) = search_query.filters {
+        let mut first_condition = true;
+
+        if let Some(pubkey) = filters.pubkey {
+            count_qb.push(" WHERE pubkey = ");
+            count_qb.push_bind(pubkey);
+            first_condition = false;
+        }
+
+        if let Some(kind) = filters.kind {
+            if first_condition {
+                count_qb.push(" WHERE ");
+            } else {
+                count_qb.push(" AND ");
+            }
+            count_qb.push("kind = ");
+            count_qb.push_bind(kind);
+            first_condition = false;
+        }
+
+        if let Some(tag_filters) = filters.tags {
+            if let Some(exact) = tag_filters.exact {
+                if first_condition {
+                    count_qb.push(" WHERE ");
+                } else {
+                    count_qb.push(" AND ");
+                }
+                count_qb.push("tags @> ");
+                count_qb.push_bind(exact);
+            }
+        }
+    }
+
+    let total: i64 = count_qb
+        .build()
+        .fetch_one(&app.db)
+        .await?
+        .get(0);
+
+    Ok(SearchResult {
+        results,
+        total,
+        limit,
+        offset,
+    })
+}
+
+//pub async fn get_similar_events(
+//    State(state): State<Arc<App>>,
+//    Path(event_id): Path<String>,
+//    Query(params): Query<std::collections::HashMap<String, String>>,
+//) -> Result<Json<Vec<EventWithSimilarity>>, String> {
+//    let limit = params
+//        .get("limit")
+//        .and_then(|l| l.parse::<i64>().ok())
+//        .unwrap_or(5);
+//
+//    let query = sqlx::query_as::<_, EventWithSimilarity>(
+//        "WITH event_embedding AS (
+//            SELECT embedding
+//            FROM nostr_search.events
+//            WHERE id = $1
+//        )
+//        SELECT
+//            ne.id,
+//            ne.pubkey,
+//            ne.created_at,
+//            ne.kind,
+//            ne.content,
+//            ne.tags,
+//            1 - (ne.embedding <=> e.embedding) as similarity
+//        FROM nostr_search.events ne, event_embedding e
+//        WHERE ne.id != $1
+//        ORDER BY ne.embedding <=> e.embedding
+//        LIMIT $2",
+//    )
+//    .bind(event_id)
+//    .bind(limit);
+//
+//    let similar_events = query
+//        .fetch_all(&state.pool)
+//        .await
+//        .map_err(|e| e.to_string())?;
+//
+//    Ok(Json(similar_events))
+//}
+//
+//pub async fn get_tag_values(
+//    State(state): State<Arc<App>>,
+//    Path(tag_key): Path<String>,
+//    Query(params): Query<std::collections::HashMap<String, String>>,
+//) -> Result<Json<Vec<TagValue>>, String> {
+//    let limit = params
+//        .get("limit")
+//        .and_then(|l| l.parse::<i64>().ok())
+//        .unwrap_or(100);
+//
+//    let query = sqlx::query_as::<_, TagValue>(
+//        "SELECT DISTINCT tag->>'value' as value, COUNT(*) as count
+//         FROM nostr_search.events,
+//              jsonb_array_elements(tags) tag
+//         WHERE tag->>'key' = $1
+//         GROUP BY tag->>'value'
+//         ORDER BY count DESC
+//         LIMIT $2",
+//    )
+//    .bind(tag_key)
+//    .bind(limit);
+//
+//    let values = query
+//        .fetch_all(&state.pool)
+//        .await
+//        .map_err(|e| e.to_string())?;
+//
+//    Ok(Json(values))
+//}
+//
+//#[derive(Debug, Serialize, sqlx::FromRow)]
+//struct TagValue {
+//    value: String,
+//    count: i64,
+//}
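The request body accepted by `search_events` is easiest to read as JSON. A hedged round-trip sketch, with the filter structs trimmed to the fields exercised (field names match the structs above; the payload values are invented):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct SearchQuery {
    query: String,
    limit: Option<i64>,
    offset: Option<i64>,
    filters: Option<SearchFilters>,
}

#[derive(Debug, Deserialize)]
struct SearchFilters {
    pubkey: Option<String>,
    kind: Option<i32>,
}

fn main() -> serde_json::Result<()> {
    // Illustrative body for POST /api/search.
    let body = r#"{
        "query": "long-form posts about relays",
        "limit": 5,
        "filters": { "kind": 1 }
    }"#;

    let q: SearchQuery = serde_json::from_str(body)?;
    println!("{q:?}"); // offset and the pubkey filter stay None
    Ok(())
}
```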
src/main.rs — 162 lines changed

@@ -1,14 +1,19 @@
-use clap::{Parser, Subcommand};
+use anyhow::Result;
+use axum::{
+    extract::State,
+    http::StatusCode,
+    response::{IntoResponse, Response},
+    routing::{post, get},
+    Json, Router,
+};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
-use std::error::Error;
-use std::io::{self, BufRead, Write};
-use sqlx::migrate::Migrator;
-use sqlx::{postgres::PgPoolOptions, PgPool};
 use std::sync::Arc;
 use tracing_subscriber;
 
 mod relay;
-#[cfg(feature = "search")]
-mod search;
+pub mod db;
+#[cfg(feature = "discovery")]
+pub mod discovery;
 
 // Types
 #[derive(Debug, Serialize, Deserialize)]

@@ -18,7 +23,32 @@ struct NostrEvent {
     created_at: i64,
     kind: i32,
     content: String,
-    tags: Value,
+    tags: serde_json::Value,
 }
 
+// App state
+#[derive(Clone)]
+pub struct App {
+    //pub db: db::Db,
+    pub web_client: reqwest::Client,
+    pub cloudflare_account_id: Option<String>,
+    // is this a bad idea? ;)
+    pub cloudflare_api_key: Option<String>,
+}
+
+use clap::{Parser, Subcommand};
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    /// Run the daemon
+    Daemon,
+}
+
 #[derive(Debug, Deserialize)]

@@ -39,47 +69,31 @@ struct PluginOutput {
     msg: Option<String>,
 }
 
-#[derive(Parser)]
-#[command(author, version, about, long_about = None)]
-struct Cli {
-    #[command(subcommand)]
-    command: Commands,
-}
-
-#[derive(Subcommand)]
-enum Commands {
-    /// Run the daemon
-    Daemon,
-}
-
-// App state definition - available regardless of features
-#[derive(Clone)]
-struct AppState {
-    #[cfg(feature = "postgres")]
-    pool: sqlx::PgPool,
-    #[cfg(feature = "search")]
-    openai_client: reqwest::Client,
-    #[cfg(feature = "search")]
-    openai_api_key: String,
-}
+use std::io::{self, BufRead, Write};
 
 #[tokio::main]
-async fn main() -> Result<(), Box<dyn Error>> {
+async fn main() -> Result<()> {
     // Initialize tracing for logging
     tracing_subscriber::fmt::init();
 
-    #[cfg(feature = "postgres")]
-    let pool = sqlx::postgres::PgPoolOptions::new()
-        .max_connections(5)
-        .connect(&std::env::var("POSTGRES_CONNECTION")?)
-        .await?;
+    //let pool = db::initDb(&std::env::var("POSTGRES_CONNECTION")?).await?;
+
+    let cf_acc_id = match std::env::var("CLOUDFLARE_ACCOUNT_ID") {
+        Ok(v) => Some(v),
+        Err(e) => None,
+    };
+
+    let cf_api_key = match std::env::var("CLOUDFLARE_API_KEY") {
+        Ok(v) => Some(v),
+        Err(e) => None,
+    };
 
     // Initialize state
-    let state = Arc::new(AppState {
-        #[cfg(feature = "postgres")]
-        pool,
-        openai_client: reqwest::Client::new(),
-        openai_api_key: std::env::var("GEMINI_API_KEY")?,
+    let state = Arc::new(App {
+        //db: pool,
+        web_client: reqwest::Client::new(),
+        cloudflare_account_id: cf_acc_id,
+        cloudflare_api_key: cf_api_key,
    });
 
     let cli = Cli::parse();

@@ -87,13 +101,65 @@ async fn main() -> Result<(), Box<dyn Error>> {
     match cli.command {
         Commands::Daemon => {
             println!("Starting daemon");
-            #[cfg(feature = "search")]
-            search::run_webserver(state).await?;
-
-            #[cfg(not(feature = "search"))]
-            println!("Search feature not enabled, webserver functionality unavailable");
+            run_webserver(state).await?;
         }
     }
 
     Ok(())
 }
 
+async fn home_path(
+    State(state): State<Arc<App>>
+) -> Result<&'static str, AppError> {
+    Ok("Hello World")
+}
+
+async fn run_webserver(state: Arc<App>) -> Result<()> {
+    println!("Welcome to noah");
+    println!("from Chakany Systems");
+    // Create router
+    let app = Router::new()
+        .route("/", get(home_path));
+        //.route("/api/events/:event_id/similar", get(get_similar_events))
+        //.route("/api/tags/:tag_key/values", get(get_tag_values))
+
+    #[cfg(feature = "search")]
+    let app = app.route("/api/search", post(api_search_events));
+    let app = app.with_state(state);
+
+    // Start server
+    println!("listening on port 3000");
+    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
+    axum::serve(listener, app).await.unwrap();
+
+    Ok(())
+}
+
+struct AppError(anyhow::Error);
+
+impl IntoResponse for AppError {
+    fn into_response(self) -> Response {
+        (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Something went wrong: {}", self.0),
+        )
+            .into_response()
+    }
+}
+
+impl<E> From<E> for AppError
+where
+    E: Into<anyhow::Error>,
+{
+    fn from(err: E) -> Self {
+        Self(err.into())
+    }
+}
+
+#[cfg(feature = "search")]
+async fn api_search_events(
+    State(state): State<Arc<App>>,
+    Json(search_query): Json<discovery::search::SearchQuery>,
+) -> Result<Json<discovery::search::SearchResult>, AppError> {
+    Ok(Json(discovery::search::search_events(&state, search_query).await?))
+}
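The `AppError` wrapper is the usual axum-plus-anyhow bridge: the blanket `From` impl lets a handler use `?` on anything convertible to `anyhow::Error`, and `IntoResponse` collapses it into a 500 response. A self-contained sketch of the pattern in use (error type and response body mirror the diff; the `/flaky` route and its handler are illustrative):

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};

struct AppError(anyhow::Error);

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}

// Any `?` on an anyhow-compatible error now works inside handlers:
async fn flaky() -> Result<String, AppError> {
    let n: i32 = "not a number".parse()?; // ParseIntError -> anyhow -> AppError -> 500
    Ok(n.to_string())
}

#[tokio::main]
async fn main() {
    let app: Router = Router::new().route("/flaky", get(flaky));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```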
src/search.rs — 322 lines changed (file deleted)

@@ -1,322 +0,0 @@
-use axum::{
-    extract::{Path, Query, State},
-    routing::{get, post},
-    Json, Router,
-};
-use pgvector::Vector;
-use serde::{Deserialize, Serialize};
-use sqlx::{postgres::PgPoolOptions, PgPool};
-use sqlx::{QueryBuilder, Row};
-use std::sync::Arc;
-
-use crate::AppState;
-
-// Search-specific types
-#[derive(Debug, Deserialize, Clone)]
-pub struct SearchQuery {
-    pub query: String,
-    pub limit: Option<i64>,
-    pub offset: Option<i64>,
-    pub filters: Option<SearchFilters>,
-}
-
-#[derive(Debug, Deserialize, Clone)]
-pub struct SearchFilters {
-    pub pubkey: Option<String>,
-    pub kind: Option<i32>,
-    pub tags: Option<TagFilters>,
-}
-
-#[derive(Debug, Deserialize, Clone)]
-pub struct TagFilters {
-    pub exact: Option<serde_json::Value>,
-    pub any: Option<Vec<String>>,
-    pub values: Option<TagValueFilter>,
-}
-
-#[derive(Debug, Deserialize, Clone)]
-pub struct TagValueFilter {
-    pub key: String,
-    pub values: Vec<String>,
-}
-
-#[derive(Debug, Serialize)]
-pub struct SearchResult {
-    pub results: Vec<EventWithSimilarity>,
-    pub total: i64,
-    pub limit: i64,
-    pub offset: i64,
-}
-
-#[derive(Debug, Serialize, sqlx::FromRow)]
-pub struct EventWithSimilarity {
-    pub id: String,
-    pub pubkey: String,
-    pub created_at: i64,
-    pub kind: i32,
-    pub content: String,
-    pub tags: serde_json::Value,
-    pub similarity: f64,
-}
-
-#[derive(Debug, Serialize, sqlx::FromRow)]
-pub struct TagValue {
-    pub value: String,
-    pub count: i64,
-}
-
-// Handlers
-async fn search_events(
-    State(state): State<Arc<AppState>>,
-    Json(search_query): Json<SearchQuery>,
-) -> Result<Json<SearchResult>, String> {
-    let limit = search_query.limit.unwrap_or(10);
-    let offset = search_query.offset.unwrap_or(0);
-
-    let embedding_vec = generate_embedding(&state, &search_query.query)
-        .await
-        .map_err(|e| e.to_string())?;
-
-    let embedding = Vector::from(embedding_vec);
-
-    // Start building the base query
-    let mut qb = QueryBuilder::new(
-        "SELECT id, pubkey, created_at, kind, content, tags, 1 - (embedding <-> ",
-    );
-
-    qb.push_bind(embedding.clone())
-        .push(") as similarity FROM nostr_search.events");
-
-    // Add WHERE clause if we have filters
-    if let Some(filters) = search_query.clone().filters {
-        let mut first_condition = true;
-
-        if let Some(pubkey) = filters.pubkey {
-            qb.push(" WHERE pubkey = ");
-            qb.push_bind(pubkey);
-            first_condition = false;
-        }
-
-        if let Some(kind) = filters.kind {
-            if first_condition {
-                qb.push(" WHERE ");
-            } else {
-                qb.push(" AND ");
-            }
-            qb.push("kind = ");
-            qb.push_bind(kind);
-            first_condition = false;
-        }
-
-        if let Some(tag_filters) = filters.tags {
-            if let Some(exact) = tag_filters.exact {
-                if first_condition {
-                    qb.push(" WHERE ");
-                } else {
-                    qb.push(" AND ");
-                }
-                qb.push("tags @> ");
-                qb.push_bind(exact);
-            }
-        }
-    }
-
-    // Add ordering, limit and offset
-    qb.push(" ORDER BY embedding <-> ")
-        .push_bind(embedding)
-        .push(" LIMIT ")
-        .push_bind(limit)
-        .push(" OFFSET ")
-        .push_bind(offset);
-
-    // Build and execute the query
-    let mut query = qb.build_query_as::<EventWithSimilarity>();
-
-    let results = query
-        .fetch_all(&state.pool)
-        .await
-        .map_err(|e| e.to_string())?;
-
-    // Build the count query
-    let mut count_qb = QueryBuilder::new("SELECT COUNT(*) FROM nostr_search.events");
-
-    if let Some(filters) = search_query.filters {
-        let mut first_condition = true;
-
-        if let Some(pubkey) = filters.pubkey {
-            count_qb.push(" WHERE pubkey = ");
-            count_qb.push_bind(pubkey);
-            first_condition = false;
-        }
-
-        if let Some(kind) = filters.kind {
-            if first_condition {
-                count_qb.push(" WHERE ");
-            } else {
-                count_qb.push(" AND ");
-            }
-            count_qb.push("kind = ");
-            count_qb.push_bind(kind);
-            first_condition = false;
-        }
-
-        if let Some(tag_filters) = filters.tags {
-            if let Some(exact) = tag_filters.exact {
-                if first_condition {
-                    count_qb.push(" WHERE ");
-                } else {
-                    count_qb.push(" AND ");
-                }
-                count_qb.push("tags @> ");
-                count_qb.push_bind(exact);
-            }
-        }
-    }
-
-    let total: i64 = count_qb
-        .build()
-        .fetch_one(&state.pool)
-        .await
-        .map_err(|e| e.to_string())?
-        .get(0);
-
-    Ok(Json(SearchResult {
-        results,
-        total,
-        limit,
-        offset,
-    }))
-}
-
-async fn get_similar_events(
-    State(state): State<Arc<AppState>>,
-    Path(event_id): Path<String>,
-    Query(params): Query<std::collections::HashMap<String, String>>,
-) -> Result<Json<Vec<EventWithSimilarity>>, String> {
-    let limit = params
-        .get("limit")
-        .and_then(|l| l.parse::<i64>().ok())
-        .unwrap_or(5);
-
-    let query = sqlx::query_as::<_, EventWithSimilarity>(
-        "WITH event_embedding AS (
-            SELECT embedding
-            FROM nostr_search.events
-            WHERE id = $1
-        )
-        SELECT
-            ne.id,
-            ne.pubkey,
-            ne.created_at,
-            ne.kind,
-            ne.content,
-            ne.tags,
-            1 - (ne.embedding <-> e.embedding) as similarity
-        FROM nostr_search.events ne, event_embedding e
-        WHERE ne.id != $1
-        ORDER BY ne.embedding <-> e.embedding
-        LIMIT $2",
-    )
-    .bind(event_id)
-    .bind(limit);
-
-    let similar_events = query
-        .fetch_all(&state.pool)
-        .await
-        .map_err(|e| e.to_string())?;
-
-    Ok(Json(similar_events))
-}
-
-async fn get_tag_values(
-    State(state): State<Arc<AppState>>,
-    Path(tag_key): Path<String>,
-    Query(params): Query<std::collections::HashMap<String, String>>,
-) -> Result<Json<Vec<TagValue>>, String> {
-    let limit = params
-        .get("limit")
-        .and_then(|l| l.parse::<i64>().ok())
-        .unwrap_or(100);
-
-    let query = sqlx::query_as::<_, TagValue>(
-        "SELECT DISTINCT tag->>'value' as value, COUNT(*) as count
-         FROM nostr_search.events,
-              jsonb_array_elements(tags) tag
-         WHERE tag->>'key' = $1
-         GROUP BY tag->>'value'
-         ORDER BY count DESC
-         LIMIT $2",
-    )
-    .bind(tag_key)
-    .bind(limit);
-
-    let values = query
-        .fetch_all(&state.pool)
-        .await
-        .map_err(|e| e.to_string())?;
-
-    Ok(Json(values))
-}
-
-pub async fn generate_embedding(
-    state: &AppState,
-    text: &str,
-) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
-    #[derive(Serialize)]
-    struct EmbeddingRequest<'a> {
-        model: &'static str,
-        content: Content<'a>,
-    }
-
-    #[derive(Serialize)]
-    struct Content<'a> {
-        parts: Vec<Part<'a>>,
-    }
-
-    #[derive(Serialize)]
-    struct Part<'a> {
-        text: &'a str,
-    }
-
-    #[derive(Deserialize)]
-    struct EmbeddingResponse {
-        embedding: Embedding,
-    }
-
-    #[derive(Deserialize)]
-    struct Embedding {
-        values: Vec<f32>,
-    }
-
-    let response = state
-        .openai_client
-        .post(format!("https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}", state.openai_api_key))
-        .json(&EmbeddingRequest {
-            model: "models/text-embedding-004",
-            content: Content {
-                parts: vec![Part { text }],
-            },
-        })
-        .send()
-        .await?
-        .json::<EmbeddingResponse>()
-        .await?;
-
-    Ok(response.embedding.values.clone())
-}
-
-pub async fn run_webserver(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error>> {
-    // Create router
-    let app = Router::new()
-        .route("/api/search", post(search_events))
-        .route("/api/events/:event_id/similar", get(get_similar_events))
-        .route("/api/tags/:tag_key/values", get(get_tag_values))
-        .with_state(state);
-
-    // Start server
-    println!("listening on port 3000");
-    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
-    axum::serve(listener, app).await.unwrap();
-
-    Ok(())
-}
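A change that is easy to miss in the move: this deleted module scored results with pgvector's `<->` operator (Euclidean/L2 distance), while its replacement in src/discovery/search.rs switched to `<=>` (cosine distance). `1 - distance` only behaves as a similarity score in the cosine case, so the new operator matches the `similarity` column it feeds. A minimal note in code form (column and table names from the diff):

```rust
// pgvector comparison operators:
//   a <-> b   L2 (Euclidean) distance
//   a <=> b   cosine distance, so 1 - (a <=> b) is cosine similarity in [-1, 1]
const OLD_SCORE: &str = "1 - (embedding <-> $1) AS similarity"; // L2: unbounded, not a similarity
const NEW_SCORE: &str = "1 - (embedding <=> $1) AS similarity"; // cosine: bounded similarity

fn main() {
    println!("before: {OLD_SCORE}\nafter:  {NEW_SCORE}");
}
```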