refactor into multipurpose

Jack Chakany 2025-03-31 09:09:54 -04:00
parent d650a7496f
commit 6578cca96d
8 changed files with 443 additions and 374 deletions

Cargo.lock (generated, 7 lines changed)

@@ -86,6 +86,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
[[package]]
name = "async-trait"
version = "0.1.86"
@@ -1161,6 +1167,7 @@ dependencies = [
name = "noah"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"clap",
"pgvector",

Cargo.toml

@@ -4,9 +4,10 @@ version = "0.1.0"
edition = "2021"
[features]
default = ["postgres", "search"]
postgres = ["sqlx/postgres", "pgvector"]
search = ["postgres"]
default = []
relay = []
discovery = ["pgvector"]
search = ["discovery"]
[dependencies]
clap = { version = "4.3", features = ["derive"] }
@@ -15,10 +16,12 @@ tokio = { version = "1.0", features = ["full"] }
sqlx = { version = "0.7", features = [
"runtime-tokio-native-tls",
"json",
], optional = true }
"postgres",
] }
pgvector = { version = "0.3", features = ["sqlx"], optional = true }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
reqwest = { version = "0.11", features = ["json"] }
tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1.0.97"
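A sketch of how the new feature matrix would be exercised at build time (`search` transitively enables `discovery`, which pulls in `pgvector`; default builds now include neither):

    cargo build                      # bare binary, no optional subsystems
    cargo build --features relay
    cargo build --features search    # implies discovery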

src/db.rs (new file, 17 lines)

@@ -0,0 +1,17 @@
use sqlx::{postgres::PgPoolOptions, PgPool, migrate::Migrator};
use anyhow::Result;
//static MIGRATOR: Migrator = sqlx::migrate!();
pub type Db = PgPool;
pub async fn init_db(connection_string: &str) -> Result<Db> {
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect(connection_string)
        .await?;
    //MIGRATOR.run(&pool).await?;
    Ok(pool)
}
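A minimal call-site sketch (mirroring the commented-out call in src/main.rs; assumes POSTGRES_CONNECTION holds a valid connection string):

    let db: db::Db = db::init_db(&std::env::var("POSTGRES_CONNECTION")?).await?;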

src/discovery/embedding.rs (new file, 54 lines)

@@ -0,0 +1,54 @@
use crate::App;
use anyhow::Result;
use serde::{Deserialize, Serialize};
pub type Embedding = Vec<f32>;
pub mod cloudflare {
    use super::*;
    use anyhow::anyhow;

    #[derive(Serialize)]
    struct EmbeddingContext<'a> {
        text: &'a str,
    }

    #[derive(Serialize)]
    struct EmbeddingRequest<'a> {
        query: Option<&'a str>, // honestly we are probably never going to use this field
        contexts: Vec<EmbeddingContext<'a>>,
    }

    #[derive(Deserialize)]
    struct EmbeddingResult {
        response: Vec<Embedding>,
    }

    #[derive(Deserialize)]
    struct EmbeddingResponse {
        result: EmbeddingResult,
        success: bool,
        errors: Vec<String>,
        messages: Vec<String>,
    }

    pub async fn generate_embedding(app: &App, text: &str) -> Result<Embedding> {
        // The Cloudflare credentials are Options on App; bail with a clear
        // error instead of trying to format the Option itself.
        let account_id = app
            .cloudflare_account_id
            .as_deref()
            .ok_or_else(|| anyhow!("CLOUDFLARE_ACCOUNT_ID is not set"))?;
        let api_key = app
            .cloudflare_api_key
            .as_deref()
            .ok_or_else(|| anyhow!("CLOUDFLARE_API_KEY is not set"))?;

        let response = app
            .web_client
            .post(format!(
                "https://api.cloudflare.com/client/v4/accounts/{}/ai/run/@cf/baai/bge-m3",
                account_id
            ))
            .header("Authorization", format!("Bearer {}", api_key))
            .json(&EmbeddingRequest {
                query: None,
                contexts: vec![EmbeddingContext { text }],
            })
            .send()
            .await?
            .json::<EmbeddingResponse>()
            .await?;

        response
            .result
            .response
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("Cloudflare returned no embeddings"))
    }
}
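A hypothetical call site (src/discovery/search.rs below is the first real caller; the text value is made up):

    let emb: Embedding = cloudflare::generate_embedding(&app, "hello nostr").await?;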

src/discovery/mod.rs (new file, 3 lines)

@@ -0,0 +1,3 @@
pub(super) mod embedding;
#[cfg(feature = "search")]
pub mod search;

src/discovery/search.rs (new file, 241 lines)

@@ -0,0 +1,241 @@
use crate::App;
use super::embedding::cloudflare::generate_embedding;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use sqlx::{QueryBuilder, Row};
use pgvector::Vector;
#[derive(Debug, Deserialize, Clone)]
struct TagFilters {
exact: Option<serde_json::Value>,
any: Option<Vec<String>>,
values: Option<TagValueFilter>,
}
#[derive(Debug, Deserialize, Clone)]
struct TagValueFilter {
key: String,
values: Vec<String>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct SearchQuery {
query: String,
limit: Option<i64>,
offset: Option<i64>,
filters: Option<SearchFilters>,
}
#[derive(Debug, Deserialize, Clone)]
struct SearchFilters {
pubkey: Option<String>,
kind: Option<i32>,
tags: Option<TagFilters>,
}
#[derive(Debug, Serialize)]
pub struct SearchResult {
results: Vec<EventWithSimilarity>,
total: i64,
limit: i64,
offset: i64,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
struct EventWithSimilarity {
id: String,
pubkey: String,
created_at: i64,
kind: i32,
content: String,
tags: serde_json::Value,
similarity: f64,
}
pub async fn search_events(app: &App, search_query: SearchQuery) -> Result<SearchResult> {
let limit = search_query.limit.unwrap_or(10);
let offset = search_query.offset.unwrap_or(0);
let embedding_vec = generate_embedding(&app, &search_query.query).await?;
let embedding = Vector::from(embedding_vec);
// Start building the base query
let mut qb = QueryBuilder::new(
"SELECT id, pubkey, created_at, kind, content, tags, 1 - (embedding <=> ",
);
qb.push_bind(embedding.clone())
.push(") as similarity FROM nostr_search.events");
// Add WHERE clause if we have filters
if let Some(filters) = search_query.clone().filters {
let mut first_condition = true;
if let Some(pubkey) = filters.pubkey {
qb.push(" WHERE pubkey = ");
qb.push_bind(pubkey);
first_condition = false;
}
if let Some(kind) = filters.kind {
if first_condition {
qb.push(" WHERE ");
} else {
qb.push(" AND ");
}
qb.push("kind = ");
qb.push_bind(kind);
first_condition = false;
}
if let Some(tag_filters) = filters.tags {
if let Some(exact) = tag_filters.exact {
if first_condition {
qb.push(" WHERE ");
} else {
qb.push(" AND ");
}
qb.push("tags @> ");
qb.push_bind(exact);
}
}
}
// Add ordering, limit and offset
qb.push(" ORDER BY embedding <=> ")
.push_bind(embedding)
.push(" LIMIT ")
.push_bind(limit)
.push(" OFFSET ")
.push_bind(offset);
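// For reference, with every filter present the assembled SQL is roughly the
// following (each push_bind adds a fresh positional parameter, so the
// embedding is bound twice, once for the score and once for the ORDER BY):
//
//   SELECT id, pubkey, created_at, kind, content, tags,
//          1 - (embedding <=> $1) AS similarity
//   FROM nostr_search.events
//   WHERE pubkey = $2 AND kind = $3 AND tags @> $4
//   ORDER BY embedding <=> $5
//   LIMIT $6 OFFSET $7
//
// `<=>` is pgvector's cosine-distance operator, so 1 - distance yields a
// cosine similarity (the old src/search.rs used `<->`, i.e. L2 distance).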
// Build and execute the query
let query = qb.build_query_as::<EventWithSimilarity>();
let results = query.fetch_all(&app.db).await?;
// Build the count query
let mut count_qb = QueryBuilder::new("SELECT COUNT(*) FROM nostr_search.events");
if let Some(filters) = search_query.filters {
let mut first_condition = true;
if let Some(pubkey) = filters.pubkey {
count_qb.push(" WHERE pubkey = ");
count_qb.push_bind(pubkey);
first_condition = false;
}
if let Some(kind) = filters.kind {
if first_condition {
count_qb.push(" WHERE ");
} else {
count_qb.push(" AND ");
}
count_qb.push("kind = ");
count_qb.push_bind(kind);
first_condition = false;
}
if let Some(tag_filters) = filters.tags {
if let Some(exact) = tag_filters.exact {
if first_condition {
count_qb.push(" WHERE ");
} else {
count_qb.push(" AND ");
}
count_qb.push("tags @> ");
count_qb.push_bind(exact);
}
}
}
let total: i64 = count_qb
.build()
.fetch_one(&app.db)
.await?
.get(0);
Ok(SearchResult {
results,
total,
limit,
offset,
})
}
//pub async fn get_similar_events(
// State(state): State<Arc<App>>,
// Path(event_id): Path<String>,
// Query(params): Query<std::collections::HashMap<String, String>>,
//) -> Result<Json<Vec<EventWithSimilarity>>, String> {
// let limit = params
// .get("limit")
// .and_then(|l| l.parse::<i64>().ok())
// .unwrap_or(5);
//
// let query = sqlx::query_as::<_, EventWithSimilarity>(
// "WITH event_embedding AS (
// SELECT embedding
// FROM nostr_search.events
// WHERE id = $1
// )
// SELECT
// ne.id,
// ne.pubkey,
// ne.created_at,
// ne.kind,
// ne.content,
// ne.tags,
// 1 - (ne.embedding <=> e.embedding) as similarity
// FROM nostr_search.events ne, event_embedding e
// WHERE ne.id != $1
// ORDER BY ne.embedding <=> e.embedding
// LIMIT $2",
// )
// .bind(event_id)
// .bind(limit);
//
// let similar_events = query
// .fetch_all(&state.pool)
// .await
// .map_err(|e| e.to_string())?;
//
// Ok(Json(similar_events))
//}
//
//pub async fn get_tag_values(
// State(state): State<Arc<App>>,
// Path(tag_key): Path<String>,
// Query(params): Query<std::collections::HashMap<String, String>>,
//) -> Result<Json<Vec<TagValue>>, String> {
// let limit = params
// .get("limit")
// .and_then(|l| l.parse::<i64>().ok())
// .unwrap_or(100);
//
// let query = sqlx::query_as::<_, TagValue>(
// "SELECT DISTINCT tag->>'value' as value, COUNT(*) as count
// FROM nostr_search.events,
// jsonb_array_elements(tags) tag
// WHERE tag->>'key' = $1
// GROUP BY tag->>'value'
// ORDER BY count DESC
// LIMIT $2",
// )
// .bind(tag_key)
// .bind(limit);
//
// let values = query
// .fetch_all(&state.pool)
// .await
// .map_err(|e| e.to_string())?;
//
// Ok(Json(values))
//}
//
//#[derive(Debug, Serialize, sqlx::FromRow)]
//struct TagValue {
// value: String,
// count: i64,
//}

src/main.rs

@@ -1,14 +1,19 @@
use clap::{Parser, Subcommand};
use anyhow::Result;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
routing::{post, get},
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::error::Error;
use std::io::{self, BufRead, Write};
use sqlx::migrate::Migrator;
use sqlx::{postgres::PgPoolOptions, PgPool};
use std::sync::Arc;
use tracing_subscriber;
mod relay;
#[cfg(feature = "search")]
mod search;
pub mod db;
#[cfg(feature = "discovery")]
pub mod discovery;
// Types
#[derive(Debug, Serialize, Deserialize)]
@@ -18,7 +23,32 @@ struct NostrEvent {
created_at: i64,
kind: i32,
content: String,
tags: Value,
tags: serde_json::Value,
}
// App state
#[derive(Clone)]
pub struct App {
//pub db: db::Db,
pub web_client: reqwest::Client,
pub cloudflare_account_id: Option<String>,
// is this a bad idea? ;)
pub cloudflare_api_key: Option<String>,
}
use clap::{Parser, Subcommand};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
/// Run the daemon
Daemon,
}
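// Usage sketch (binary name from Cargo.lock; `daemon` is the only subcommand
// so far):
//   $ noah daemon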
#[derive(Debug, Deserialize)]
@@ -39,47 +69,31 @@ struct PluginOutput {
msg: Option<String>,
}
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
/// Run the daemon
Daemon,
}
// App state definition - available regardless of features
#[derive(Clone)]
struct AppState {
#[cfg(feature = "postgres")]
pool: sqlx::PgPool,
#[cfg(feature = "search")]
openai_client: reqwest::Client,
#[cfg(feature = "search")]
openai_api_key: String,
}
use std::io::{self, BufRead, Write};
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<()> {
// Initialize tracing for logging
tracing_subscriber::fmt::init();
#[cfg(feature = "postgres")]
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect(&std::env::var("POSTGRES_CONNECTION")?)
.await?;
//let pool = db::init_db(&std::env::var("POSTGRES_CONNECTION")?).await?;
let cf_acc_id = std::env::var("CLOUDFLARE_ACCOUNT_ID").ok();
let cf_api_key = std::env::var("CLOUDFLARE_API_KEY").ok();
// Initialize state
let state = Arc::new(AppState {
#[cfg(feature = "postgres")]
pool,
openai_client: reqwest::Client::new(),
openai_api_key: std::env::var("GEMINI_API_KEY")?,
let state = Arc::new(App {
//db: pool,
web_client: reqwest::Client::new(),
cloudflare_account_id: cf_acc_id,
cloudflare_api_key: cf_api_key,
});
let cli = Cli::parse();
@@ -87,13 +101,65 @@ async fn main() -> Result<(), Box<dyn Error>> {
match cli.command {
Commands::Daemon => {
println!("Starting daemon");
#[cfg(feature = "search")]
search::run_webserver(state).await?;
#[cfg(not(feature = "search"))]
println!("Search feature not enabled, webserver functionality unavailable");
run_webserver(state).await?;
}
}
Ok(())
}
async fn home_path(
    State(_state): State<Arc<App>>
) -> Result<&'static str, AppError> {
Ok("Hello World")
}
async fn run_webserver(state: Arc<App>) -> Result<()> {
println!("Welcome to noah");
println!("from Chakany Systems");
// Create router
let app = Router::new()
.route("/", get(home_path));
//.route("/api/events/:event_id/similar", get(get_similar_events))
//.route("/api/tags/:tag_key/values", get(get_tag_values))
#[cfg(feature = "search")]
let app = app.route("/api/search", post(api_search_events));
let app = app.with_state(state);
// Start server
println!("listening on port 3000");
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
axum::serve(listener, app).await?;
Ok(())
}
struct AppError(anyhow::Error);
impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}
#[cfg(feature = "search")]
async fn api_search_events(
    State(state): State<Arc<App>>,
    Json(search_query): Json<discovery::search::SearchQuery>,
) -> Result<Json<discovery::search::SearchResult>, AppError> {
    Ok(Json(discovery::search::search_events(&state, search_query).await?))
}
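A hypothetical smoke test against a running daemon (port and route from run_webserver; reqwest is already a dependency, and the query values are made up):

    let resp = reqwest::Client::new()
        .post("http://localhost:3000/api/search")
        .json(&serde_json::json!({ "query": "nostr relays", "limit": 5, "filters": { "kind": 1 } }))
        .send()
        .await?;
    println!("{}", resp.text().await?);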

src/search.rs (deleted, 322 lines)

@@ -1,322 +0,0 @@
use axum::{
extract::{Path, Query, State},
routing::{get, post},
Json, Router,
};
use pgvector::Vector;
use serde::{Deserialize, Serialize};
use sqlx::{postgres::PgPoolOptions, PgPool};
use sqlx::{QueryBuilder, Row};
use std::sync::Arc;
use crate::AppState;
// Search-specific types
#[derive(Debug, Deserialize, Clone)]
pub struct SearchQuery {
pub query: String,
pub limit: Option<i64>,
pub offset: Option<i64>,
pub filters: Option<SearchFilters>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct SearchFilters {
pub pubkey: Option<String>,
pub kind: Option<i32>,
pub tags: Option<TagFilters>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct TagFilters {
pub exact: Option<serde_json::Value>,
pub any: Option<Vec<String>>,
pub values: Option<TagValueFilter>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct TagValueFilter {
pub key: String,
pub values: Vec<String>,
}
#[derive(Debug, Serialize)]
pub struct SearchResult {
pub results: Vec<EventWithSimilarity>,
pub total: i64,
pub limit: i64,
pub offset: i64,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct EventWithSimilarity {
pub id: String,
pub pubkey: String,
pub created_at: i64,
pub kind: i32,
pub content: String,
pub tags: serde_json::Value,
pub similarity: f64,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct TagValue {
pub value: String,
pub count: i64,
}
// Handlers
async fn search_events(
State(state): State<Arc<AppState>>,
Json(search_query): Json<SearchQuery>,
) -> Result<Json<SearchResult>, String> {
let limit = search_query.limit.unwrap_or(10);
let offset = search_query.offset.unwrap_or(0);
let embedding_vec = generate_embedding(&state, &search_query.query)
.await
.map_err(|e| e.to_string())?;
let embedding = Vector::from(embedding_vec);
// Start building the base query
let mut qb = QueryBuilder::new(
"SELECT id, pubkey, created_at, kind, content, tags, 1 - (embedding <-> ",
);
qb.push_bind(embedding.clone())
.push(") as similarity FROM nostr_search.events");
// Add WHERE clause if we have filters
if let Some(filters) = search_query.clone().filters {
let mut first_condition = true;
if let Some(pubkey) = filters.pubkey {
qb.push(" WHERE pubkey = ");
qb.push_bind(pubkey);
first_condition = false;
}
if let Some(kind) = filters.kind {
if first_condition {
qb.push(" WHERE ");
} else {
qb.push(" AND ");
}
qb.push("kind = ");
qb.push_bind(kind);
first_condition = false;
}
if let Some(tag_filters) = filters.tags {
if let Some(exact) = tag_filters.exact {
if first_condition {
qb.push(" WHERE ");
} else {
qb.push(" AND ");
}
qb.push("tags @> ");
qb.push_bind(exact);
}
}
}
// Add ordering, limit and offset
qb.push(" ORDER BY embedding <-> ")
.push_bind(embedding)
.push(" LIMIT ")
.push_bind(limit)
.push(" OFFSET ")
.push_bind(offset);
// Build and execute the query
let mut query = qb.build_query_as::<EventWithSimilarity>();
let results = query
.fetch_all(&state.pool)
.await
.map_err(|e| e.to_string())?;
// Build the count query
let mut count_qb = QueryBuilder::new("SELECT COUNT(*) FROM nostr_search.events");
if let Some(filters) = search_query.filters {
let mut first_condition = true;
if let Some(pubkey) = filters.pubkey {
count_qb.push(" WHERE pubkey = ");
count_qb.push_bind(pubkey);
first_condition = false;
}
if let Some(kind) = filters.kind {
if first_condition {
count_qb.push(" WHERE ");
} else {
count_qb.push(" AND ");
}
count_qb.push("kind = ");
count_qb.push_bind(kind);
first_condition = false;
}
if let Some(tag_filters) = filters.tags {
if let Some(exact) = tag_filters.exact {
if first_condition {
count_qb.push(" WHERE ");
} else {
count_qb.push(" AND ");
}
count_qb.push("tags @> ");
count_qb.push_bind(exact);
}
}
}
let total: i64 = count_qb
.build()
.fetch_one(&state.pool)
.await
.map_err(|e| e.to_string())?
.get(0);
Ok(Json(SearchResult {
results,
total,
limit,
offset,
}))
}
async fn get_similar_events(
State(state): State<Arc<AppState>>,
Path(event_id): Path<String>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<Json<Vec<EventWithSimilarity>>, String> {
let limit = params
.get("limit")
.and_then(|l| l.parse::<i64>().ok())
.unwrap_or(5);
let query = sqlx::query_as::<_, EventWithSimilarity>(
"WITH event_embedding AS (
SELECT embedding
FROM nostr_search.events
WHERE id = $1
)
SELECT
ne.id,
ne.pubkey,
ne.created_at,
ne.kind,
ne.content,
ne.tags,
1 - (ne.embedding <-> e.embedding) as similarity
FROM nostr_search.events ne, event_embedding e
WHERE ne.id != $1
ORDER BY ne.embedding <-> e.embedding
LIMIT $2",
)
.bind(event_id)
.bind(limit);
let similar_events = query
.fetch_all(&state.pool)
.await
.map_err(|e| e.to_string())?;
Ok(Json(similar_events))
}
async fn get_tag_values(
State(state): State<Arc<AppState>>,
Path(tag_key): Path<String>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<Json<Vec<TagValue>>, String> {
let limit = params
.get("limit")
.and_then(|l| l.parse::<i64>().ok())
.unwrap_or(100);
let query = sqlx::query_as::<_, TagValue>(
"SELECT DISTINCT tag->>'value' as value, COUNT(*) as count
FROM nostr_search.events,
jsonb_array_elements(tags) tag
WHERE tag->>'key' = $1
GROUP BY tag->>'value'
ORDER BY count DESC
LIMIT $2",
)
.bind(tag_key)
.bind(limit);
let values = query
.fetch_all(&state.pool)
.await
.map_err(|e| e.to_string())?;
Ok(Json(values))
}
pub async fn generate_embedding(
state: &AppState,
text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
model: &'static str,
content: Content<'a>,
}
#[derive(Serialize)]
struct Content<'a> {
parts: Vec<Part<'a>>,
}
#[derive(Serialize)]
struct Part<'a> {
text: &'a str,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
embedding: Embedding,
}
#[derive(Deserialize)]
struct Embedding {
values: Vec<f32>,
}
let response = state
.openai_client
.post(format!("https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}", state.openai_api_key))
.json(&EmbeddingRequest {
model: "models/text-embedding-004",
content: Content {
parts: vec![Part { text }],
},
})
.send()
.await?
.json::<EmbeddingResponse>()
.await?;
Ok(response.embedding.values.clone())
}
pub async fn run_webserver(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error>> {
// Create router
let app = Router::new()
.route("/api/search", post(search_events))
.route("/api/events/:event_id/similar", get(get_similar_events))
.route("/api/tags/:tag_key/values", get(get_tag_values))
.with_state(state);
// Start server
println!("listening on port 3000");
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
Ok(())
}