chrowatch/backend/src/database.py

103 lines
2.9 KiB
Python
Raw Normal View History

2024-07-23 16:58:58 -04:00
import sqlite3
import json
import numpy as np
import os
# Path of the SQLite database file opened by get_db_connection().
# NOTE(review): this is a relative path, so it presumably resolves against the
# process's current working directory -- confirm that is intended.
DATABASE_PATH = 'videos.db'
def dict_factory(cursor, row):
    """Build a dict for one result row, keyed by column name.

    Has the ``(cursor, row)`` signature expected of a ``sqlite3`` row factory.

    Args:
        cursor: Object whose ``description`` holds one entry per result
            column; the first item of each entry is used as the key.
        row: Sequence of values for the row, indexed in column order.

    Returns:
        A new dict mapping each column name to its value in ``row``.
    """
    return {
        column[0]: row[position]
        for position, column in enumerate(cursor.description)
    }
def get_db_connection():
    """Open a new SQLite connection to ``DATABASE_PATH``.

    Rows fetched through the returned connection are built by
    ``dict_factory``, i.e. they are plain dicts keyed by column name.

    Returns:
        A fresh ``sqlite3.Connection``. It is not closed here; the caller
        is responsible for closing it.
    """
    connection = sqlite3.connect(DATABASE_PATH)
    connection.row_factory = dict_factory
    return connection
def init_db():
    """Create the ``videos`` table, or migrate an existing one.

    If the table already exists, a ``filepath`` column is added when it is
    missing. Otherwise the table is created with the ``id``, ``filename``,
    ``filepath`` and ``embedding`` columns.

    The connection is closed even when a statement fails, so a failed
    initialization does not leak an open database handle.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()

        # Check if the videos table exists
        c.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='videos'")
        table_exists = c.fetchone()

        if table_exists:
            # If the table exists, check if the filepath column exists
            c.execute("PRAGMA table_info(videos)")
            columns = [column['name'] for column in c.fetchall()]
            if 'filepath' not in columns:
                # SQLite cannot add a NOT NULL column without a default, so
                # the migrated column is nullable, unlike the fresh schema.
                c.execute("ALTER TABLE videos ADD COLUMN filepath TEXT")
        else:
            # If the table doesn't exist, create it with all necessary columns.
            # The statement text is kept verbatim.
            c.execute('''CREATE TABLE videos
(id INTEGER PRIMARY KEY AUTOINCREMENT,
filename TEXT NOT NULL,
filepath TEXT NOT NULL,
embedding TEXT)''')

        conn.commit()
    finally:
        conn.close()
def insert_video(filename, filepath):
    """Insert a new row into ``videos`` and return its id.

    Args:
        filename: Value stored in the ``filename`` column.
        filepath: Value stored in the ``filepath`` column.

    Returns:
        The cursor's ``lastrowid`` after the INSERT.

    The connection is closed even if the INSERT or the commit raises.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("INSERT INTO videos (filename, filepath) VALUES (?, ?)",
                  (filename, filepath))
        video_id = c.lastrowid
        conn.commit()
        return video_id
    finally:
        conn.close()
def update_video_embedding(video_id, embedding):
    """Store ``embedding`` on the ``videos`` row with the given id.

    Args:
        video_id: Value matched against the ``id`` column.
        embedding: Object providing ``flatten()`` whose result provides
            ``tolist()`` (presumably a NumPy array -- confirm against
            callers). It is flattened and stored as a JSON list, so its
            original shape is not preserved.

    The connection is closed even if the UPDATE or the commit raises.
    """
    # Serialize before opening the connection so a bad embedding fails early.
    embedding_json = json.dumps(embedding.flatten().tolist())

    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("UPDATE videos SET embedding = ? WHERE id = ?",
                  (embedding_json, video_id))
        conn.commit()
    finally:
        conn.close()
def get_all_videos():
    """Return every row of ``videos`` with embeddings decoded.

    Returns:
        A list of row dicts. Where ``embedding`` is truthy it is replaced
        by a NumPy array reshaped to a single row, ``(1, -1)``. Rows with
        no embedding are returned unchanged.

    The connection is closed even if the query raises.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM videos")
        videos = c.fetchall()
    finally:
        conn.close()

    for video in videos:
        if video['embedding']:
            # NOTE(review): this returns a 2-D (1, n) array, whereas
            # get_video_by_id / get_vectorized_videos return the flat decoded
            # array -- confirm the difference is intentional.
            video['embedding'] = np.array(json.loads(video['embedding'])).reshape(1, -1)

    return videos
def get_video_by_id(video_id):
    """Fetch a single ``videos`` row by id.

    Args:
        video_id: Value matched against the ``id`` column.

    Returns:
        The row dict, or ``None`` when no row matches. A truthy
        ``embedding`` value is decoded from JSON into a NumPy array.

    The connection is closed even if the query raises.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM videos WHERE id = ?", (video_id,))
        video = c.fetchone()
    finally:
        conn.close()

    if video and video['embedding']:
        video['embedding'] = np.array(json.loads(video['embedding']))

    return video
def get_vectorized_videos():
    """Return the ``videos`` rows whose ``embedding`` is not NULL.

    Returns:
        A list of row dicts, each with ``embedding`` decoded from JSON
        into a NumPy array.

    The connection is closed even if the query raises.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM videos WHERE embedding IS NOT NULL")
        videos = c.fetchall()
    finally:
        conn.close()

    for video in videos:
        video['embedding'] = np.array(json.loads(video['embedding']))

    return videos
# Initialize the database when this module is imported: importing the module
# runs init_db(), which creates the videos table or adds its filepath column.
init_db()