# End-to-End Ontology System: Building a Complete Ontology-Driven Application

This guide walks through building a complete, production-ready ontology-driven application for plant disease diagnosis, from data collection to deployment.

## System Architecture

```mermaid
graph TB
    A[Data Collection] --> B[Ontology Engineering]
    B --> C[Knowledge Graph Construction]
    C --> D[ML Model Training]
    D --> E[API Development]
    E --> F[Frontend Interface]
    F --> G[Deployment]
    H[Monitoring & Feedback] --> B
    H --> D
    style A fill:#f9f,stroke:#333
    style G fill:#9cf,stroke:#333
```
## Data Pipeline

### 1. Data Collection

```python
from pathlib import Path

import pandas as pd


class PlantDiseaseDataCollector:
    def __init__(self, output_dir: str = "data"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

    def collect_from_plantvillage(self, limit: int = 100) -> pd.DataFrame:
        """Collect plant disease data from the PlantVillage dataset."""
        # Requires Kaggle API credentials in ~/.kaggle/kaggle.json
        from kaggle.api.kaggle_api_extended import KaggleApi

        # Download the dataset
        api = KaggleApi()
        api.authenticate()
        api.dataset_download_files(
            'emmarex/plantdisease',
            path=str(self.output_dir),
            unzip=True
        )

        # Process and return metadata
        return self._process_plantvillage_data()

    def _process_plantvillage_data(self) -> pd.DataFrame:
        """Process downloaded PlantVillage data into a metadata table."""
        data = []
        for img_path in (self.output_dir / "PlantVillage").rglob("*.JPG"):
            # Directory names encode "<plant>___<disease>"
            label = img_path.parts[-2]
            data.append({
                'image_path': str(img_path),
                'disease': label.split('___')[-1],
                'plant': label.split('___')[0],
                'source': 'PlantVillage'
            })
        return pd.DataFrame(data)
```

### 2. Data Preprocessing

```python
import cv2
from torch.utils.data import Dataset, DataLoader


class PlantDiseaseDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        self.classes = sorted(df['disease'].unique())
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load image (OpenCV reads BGR; convert to RGB)
        img = cv2.imread(row['image_path'])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Apply transformations
        if self.transform:
            img = self.transform(image=img)['image']

        # Get label
        label = self.class_to_idx[row['disease']]

        return {
            'image': img,
            'label': label,
            'plant': row['plant'],
            'disease': row['disease']
        }


def get_data_loaders(df, batch_size=32, val_split=0.2):
    """Create train and validation data loaders."""
    from sklearn.model_selection import train_test_split
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    # Split the data, stratified so class balance is preserved in both splits
    train_df, val_df = train_test_split(
        df, test_size=val_split, stratify=df['disease'], random_state=42
    )

    # Define transforms
    train_transform = A.Compose([
        A.RandomResizedCrop(224, 224),  # (height, width); newer albumentations uses size=(224, 224)
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    val_transform = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Create datasets
    train_ds = PlantDiseaseDataset(train_df, transform=train_transform)
    val_ds = PlantDiseaseDataset(val_df, transform=val_transform)

    # Create data loaders
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
    )

    return train_loader, val_loader, train_ds.classes
```
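
Before moving on, it is worth wiring these pieces together. A minimal sketch, assuming the default `data/` directory and that Kaggle credentials are configured as above:

```python
# Sketch: collect metadata, then build the loaders from it
collector = PlantDiseaseDataCollector(output_dir="data")
df = collector.collect_from_plantvillage()

train_loader, val_loader, classes = get_data_loaders(df, batch_size=32)
print(f"{len(df)} images across {len(classes)} disease classes")
```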
## Model Development

### 1. Model Architecture

```python
import torch.nn as nn
import timm


class PlantDiseaseModel(nn.Module):
    def __init__(self, num_classes, model_name='efficientnet_b0', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0  # Return pooled features, not logits
        )
        in_features = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)
```
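
A quick shape sanity check confirms the head lines up with the backbone's feature size; the batch size and class count below are arbitrary:

```python
import torch

# Sketch: a dummy forward pass through the untrained model
model = PlantDiseaseModel(num_classes=15)
dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
print(model(dummy).shape)  # expected: torch.Size([2, 15])
```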
### 2. Training Loop

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


class PlantDiseaseClassifier(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = PlantDiseaseModel(num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch['image']
        y = batch['label']
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch['image']
        y = batch['label']
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = {
            'scheduler': CosineAnnealingLR(
                optimizer,
                T_max=10,
                eta_min=1e-6
            ),
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [scheduler]


def train_model(train_loader, val_loader, num_classes, max_epochs=20):
    """Train the plant disease classification model."""
    # Define callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        dirpath='checkpoints/',
        filename='plant-disease-{epoch:02d}-{val_acc:.3f}'
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    # Initialize model and trainer.
    # Lightning 2.x: accelerator/devices replace the old gpus= argument,
    # and progress_bar_refresh_rate= was removed.
    model = PlantDiseaseClassifier(num_classes=num_classes)
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, early_stop_callback],
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        logger=pl.loggers.TensorBoardLogger('lightning_logs/')
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)
    return model, trainer
```
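
The class list produced by the data pipeline is needed later by the API to map logits back to disease names, so it is worth persisting alongside the checkpoint. A sketch continuing from the data-pipeline snippet above (the `checkpoints/classes.json` path is an illustrative choice, not fixed by the code):

```python
import json

# Train, then save the class names next to the checkpoints
model, trainer = train_model(train_loader, val_loader, num_classes=len(classes))

with open("checkpoints/classes.json", "w") as f:
    json.dump(list(classes), f)
```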
## Knowledge Graph Integration

### 1. GraphDB Setup

```python
from SPARQLWrapper import SPARQLWrapper, JSON
import pandas as pd


class KnowledgeGraphManager:
    def __init__(self, endpoint: str = "http://localhost:7200/repositories/plant-disease"):
        self.sparql = SPARQLWrapper(endpoint)
        self.sparql.setReturnFormat(JSON)

    def query(self, query: str) -> pd.DataFrame:
        """Execute a SPARQL query and return the results as a DataFrame."""
        self.sparql.setQuery(query)
        results = self.sparql.query().convert()

        # Flatten the JSON bindings into rows
        rows = []
        for result in results["results"]["bindings"]:
            row = {}
            for var in results["head"]["vars"]:
                if var in result:
                    row[var] = result[var]["value"]
            rows.append(row)

        return pd.DataFrame(rows)

    def get_disease_info(self, disease_uri: str) -> dict:
        """Get detailed information about a disease as a property/value mapping."""
        query = f"""
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

        SELECT ?property ?value
        WHERE {{
            <{disease_uri}> ?p ?o .
            ?p rdfs:label ?property .
            BIND(IF(isLiteral(?o), ?o, STR(?o)) AS ?value)
        }}
        """
        results = self.query(query)
        if results.empty:
            return {}
        return dict(zip(results['property'], results['value']))
```
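
`get_disease_info` assumes each disease resource carries properties whose predicates have `rdfs:label`s such as `treatment` and `prevention`. A minimal rdflib sketch of what such data might look like; the `http://example.org/...` URIs, property names, and disease label are illustrative values chosen to match the query above, not a published ontology:

```python
from rdflib import Graph, Literal, Namespace, URIRef
from rdflib.namespace import RDFS

PLANT = Namespace("http://example.org/plant-ontology#")
g = Graph()

# Label the predicates so get_disease_info can surface them by name
g.add((PLANT.treatment, RDFS.label, Literal("treatment")))
g.add((PLANT.prevention, RDFS.label, Literal("prevention")))

# One illustrative disease individual (name matches the PlantVillage-style labels)
disease = URIRef("http://example.org/disease/Early_blight")
g.add((disease, PLANT.treatment, Literal("Apply copper-based fungicide; remove infected leaves.")))
g.add((disease, PLANT.prevention, Literal("Rotate crops and avoid overhead watering.")))

# Serialize to Turtle for loading into the GraphDB repository
g.serialize("plant-disease.ttl", format="turtle")
```

The resulting Turtle file can then be imported into the `plant-disease` repository through the GraphDB workbench.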
## API Development

### 1. FastAPI Application

```python
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import uvicorn
import torch
from torchvision import transforms
from PIL import Image
import io

# Initialize FastAPI app
app = FastAPI(title="Plant Disease Diagnosis API")

# CORS middleware (restrict allow_origins for production deployments)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models and components
model = None  # Populated at startup
kg_manager = KnowledgeGraphManager()  # from the knowledge-graph section above


class DiagnosisResult(BaseModel):
    disease: str
    confidence: float
    treatment: str
    prevention: str
    scientific_info: dict


@app.on_event("startup")
async def startup_event():
    """Initialize models and services."""
    global model
    # Load your trained model here
    # model = load_model(...)
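
# Hypothetical loader sketch (not part of the original API): assumes the
# training code is importable as `train` and that the class names were
# saved to checkpoints/classes.json, as in the training section above.
# Wire it in by replacing the placeholder in startup_event.
def load_model(checkpoint_path: str = "checkpoints/best.ckpt"):
    import json
    from train import PlantDiseaseClassifier  # assumed module name

    clf = PlantDiseaseClassifier.load_from_checkpoint(checkpoint_path)
    clf.eval()
    with open("checkpoints/classes.json") as f:
        clf.classes = json.load(f)  # label names used by /diagnose
    return clf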
@app.post("/diagnose", response_model=DiagnosisResult)
async def diagnose_plant(
    image: UploadFile = File(...),
    symptoms: Optional[str] = Form(None)  # Form(), not a bare str, so multipart clients can send it
):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Read and preprocess image
        contents = await image.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')

        # Convert to tensor and add batch dimension
        img_tensor = preprocess_image(img).unsqueeze(0)

        # Get prediction
        with torch.no_grad():
            logits = model(img_tensor)
            probs = torch.softmax(logits, dim=1)
            confidence, pred_idx = torch.max(probs, dim=1)

        # Get class name
        class_idx = pred_idx.item()
        class_name = model.classes[class_idx]

        # Get knowledge graph information
        disease_info = kg_manager.get_disease_info(f"http://example.org/disease/{class_name}")

        return {
            'disease': class_name,
            'confidence': confidence.item(),
            'treatment': disease_info.get('treatment', 'No treatment information available'),
            'prevention': disease_info.get('prevention', 'No prevention information available'),
            'scientific_info': disease_info
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def preprocess_image(image, size=(224, 224)):
    """Preprocess a PIL image for model inference."""
    # Resize and normalize with ImageNet statistics
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform(image)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
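
Once the server is running, the endpoint can be smoke-tested with a short client script; the image path below is a placeholder:

```python
import requests

# Sketch: POST a local leaf image to the running API
with open("samples/leaf.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/diagnose",
        files={"image": ("leaf.jpg", f, "image/jpeg")},
        data={"symptoms": "Yellowing leaves"},
    )
print(resp.status_code, resp.json())
```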
## Frontend Development

### 1. Streamlit Interface

```python
import streamlit as st
import requests
from PIL import Image

# Page config
st.set_page_config(
    page_title="Plant Disease Diagnosis",
    page_icon="🌱",
    layout="wide"
)

# Custom CSS
st.markdown("""
<style>
    .main {
        background-color: #f8f9fa;
    }
    .stButton>button {
        background-color: #28a745;
        color: white;
        font-weight: bold;
    }
    .stButton>button:hover {
        background-color: #218838;
    }
</style>
""", unsafe_allow_html=True)

# App title
st.title("🌱 Plant Disease Diagnosis")
st.markdown("Upload an image of a plant to diagnose potential diseases.")

# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Additional symptoms
st.sidebar.header("Additional Symptoms")
symptoms = st.sidebar.multiselect(
    "Select any additional symptoms you observe:",
    ["Yellowing leaves", "Brown spots", "White powder", "Wilting", "Stunted growth"]
)

if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_container_width=True)

    # Make a prediction when the button is clicked
    if st.button('Diagnose Plant'):
        with st.spinner('Analyzing...'):
            try:
                # Call the API
                response = requests.post(
                    "http://localhost:8000/diagnose",
                    files={"image": (uploaded_file.name, uploaded_file.getvalue())},
                    data={"symptoms": ",".join(symptoms) if symptoms else ""}
                )

                if response.status_code == 200:
                    result = response.json()

                    # Display results
                    st.success("Analysis Complete!")
                    col1, col2 = st.columns(2)

                    with col1:
                        st.subheader("Diagnosis")
                        st.metric("Disease", result['disease'])
                        st.metric("Confidence", f"{result['confidence']*100:.1f}%")

                        st.subheader("Treatment")
                        st.info(result['treatment'])

                    with col2:
                        st.subheader("Prevention")
                        st.warning(result['prevention'])

                        st.subheader("Scientific Information")
                        for key, value in result['scientific_info'].items():
                            if key not in ['treatment', 'prevention']:
                                st.text(f"{key}: {value}")
                else:
                    st.error(f"Error: {response.text}")
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

# Add footer
st.markdown("---")
st.markdown("### About")
st.markdown("""
This application uses deep learning and knowledge graphs to diagnose plant diseases.
The system combines computer vision with semantic reasoning for accurate diagnosis.
""")
# Launch with: streamlit run app.py
```
## Deployment

### 1. Docker Compose

```yaml
version: '3.8'

services:
  # API Service
  api:
    build:
      context: .
      dockerfile: Dockerfile.api
    ports:
      - "8000:8000"
    environment:
      - GRAPHDB_ENDPOINT=http://graphdb:7200/repositories/plant-disease
    depends_on:
      - graphdb
    restart: unless-stopped

  # Frontend Service
  frontend:
    build:
      context: .
      dockerfile: Dockerfile.frontend
    ports:
      - "8501:8501"
    environment:
      - API_URL=http://api:8000
    depends_on:
      - api
    restart: unless-stopped

  # GraphDB
  graphdb:
    image: ontotext/graphdb:10.1.0
    ports:
      - "7200:7200"
    volumes:
      - graphdb-data:/opt/graphdb/home/data
    environment:
      - GDB_HEAP_SIZE=2g
    restart: unless-stopped

volumes:
  graphdb-data:
```
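
To bring the full stack up locally, run `docker compose up --build`; the Streamlit UI is then on port 8501, the API on port 8000, and the GraphDB workbench on port 7200.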
### 2. Monitoring and Maintenance

#### 1. Logging and Metrics

```python
import logging
import time

from fastapi import Request
from prometheus_client import start_http_server, Counter, Histogram

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Prometheus metrics
REQUESTS = Counter('api_requests_total', 'Total API requests', ['endpoint', 'method'])
LATENCY = Histogram('api_latency_seconds', 'API latency in seconds', ['endpoint'])
ERRORS = Counter('api_errors_total', 'Total API errors', ['endpoint', 'error_type'])
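
# Expose the metrics for Prometheus to scrape. A standalone exporter on a
# side port is one simple option; port 9100 is an arbitrary choice here.
start_http_server(9100)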

@app.middleware("http")
async def monitor_requests(request: Request, call_next):
    """Middleware to monitor API requests and collect metrics."""
    start_time = time.time()
    endpoint = request.url.path
    method = request.method

    try:
        REQUESTS.labels(endpoint=endpoint, method=method).inc()
        response = await call_next(request)

        # Record latency
        latency = time.time() - start_time
        LATENCY.labels(endpoint=endpoint).observe(latency)

        return response
    except Exception as e:
        error_type = type(e).__name__
        ERRORS.labels(endpoint=endpoint, error_type=error_type).inc()
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise
```

## Next Steps
- **Data Augmentation**: Expand the dataset with more plant species and diseases
- **Model Optimization**: Prune and quantize models for edge deployment
- **Active Learning**: Continuously improve the model with user feedback
- **Mobile App**: Develop native mobile applications for field use
- **Multi-language Support**: Add support for multiple languages in the interface