End-to-End Ontology System

Building a Complete Ontology-Driven Application

This guide walks through building a complete, production-ready ontology-driven application for plant disease diagnosis, from data collection to deployment.

System Architecture

graph TB
    %% Main build pipeline, top to bottom: data collection through deployment.
    A[Data Collection] --> B[Ontology Engineering]
    B --> C[Knowledge Graph Construction]
    C --> D[ML Model Training]
    D --> E[API Development]
    E --> F[Frontend Interface]
    F --> G[Deployment]
    
    %% Feedback loop: monitoring feeds back into ontology engineering (B) and model training (D).
    H[Monitoring & Feedback] --> B
    H --> D
    
    %% Highlight the pipeline's first (A) and last (G) stages.
    style A fill:#f9f,stroke:#333
    style G fill:#9cf,stroke:#333

Data Pipeline

1. Data Collection

import requests
from pathlib import Path
from typing import List, Dict
import pandas as pd

class PlantDiseaseDataCollector:
    """Download the PlantVillage dataset and index its images as a DataFrame.

    Images are expected under ``<output_dir>/PlantVillage/<class folder>/``.
    The class folder name is split on ``'___'``: the first part is the plant
    and the last part is the disease.
    """

    # Image file suffixes to index. They are compared case-insensitively so
    # that ``.jpg`` files are not skipped on case-sensitive filesystems.
    IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg"})
    # Column order of the returned index (also used when it is empty).
    COLUMNS = ("image_path", "disease", "plant", "source")

    def __init__(self, output_dir: str = "data"):
        self.output_dir = Path(output_dir)
        # parents=True so nested output paths such as "cache/data" work too.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def collect_from_plantvillage(self, limit: int = 100) -> pd.DataFrame:
        """Download PlantVillage from Kaggle and return its image index.

        Requires Kaggle API credentials, because ``KaggleApi.authenticate()``
        is called.

        Args:
            limit: Accepted for API compatibility but not applied. The full
                dataset is downloaded and indexed.

        Returns:
            One row per image with columns ``image_path``, ``disease``,
            ``plant`` and ``source``.
        """
        # Importing the submodule also imports the top-level ``kaggle`` package.
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()
        api.dataset_download_files(
            'emmarex/plantdisease',
            path=str(self.output_dir),
            unzip=True,
        )
        return self._process_plantvillage_data()

    def _process_plantvillage_data(self) -> pd.DataFrame:
        """Index downloaded PlantVillage images.

        If no images are found, the result is an empty DataFrame that still
        has the expected columns, so downstream column access does not raise
        KeyError.
        """
        root = self.output_dir / "PlantVillage"
        if not root.is_dir():
            return pd.DataFrame(columns=list(self.COLUMNS))

        rows = []
        # sorted() makes the row order deterministic across filesystems.
        for img_path in sorted(root.rglob("*")):
            if not img_path.is_file() or img_path.suffix.lower() not in self.IMAGE_SUFFIXES:
                continue
            label = img_path.parent.name
            rows.append({
                'image_path': str(img_path),
                'disease': label.split('___')[-1],
                'plant': label.split('___')[0],
                'source': 'PlantVillage',
            })
        return pd.DataFrame(rows, columns=list(self.COLUMNS))

2. Data Preprocessing

import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader

class PlantDiseaseDataset(Dataset):
    """Turn rows of an image/label DataFrame into model-ready samples.

    Args:
        df: DataFrame with ``image_path``, ``plant`` and ``disease`` columns.
        transform: Optional callable invoked as
            ``transform(image=img)['image']`` (albumentations-style).
        classes: Optional explicit disease-name list that fixes the label
            order. Pass the training set's ``classes`` to validation/test
            datasets so every split shares one mapping. When omitted, the
            classes are the sorted unique diseases in ``df``.
    """

    def __init__(self, df, transform=None, classes=None):
        self.df = df
        self.transform = transform
        self.classes = sorted(df['disease'].unique()) if classes is None else list(classes)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with ``image``, integer ``label``, ``plant`` and ``disease``."""
        row = self.df.iloc[idx]

        img = cv2.imread(row['image_path'])
        # cv2.imread returns None instead of raising for missing or unreadable
        # files. Fail with a clear message rather than an obscure cvtColor error.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {row['image_path']}")
        # OpenCV loads BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)['image']

        return {
            'image': img,
            'label': self.class_to_idx[row['disease']],
            'plant': row['plant'],
            'disease': row['disease'],
        }

def get_data_loaders(df, batch_size=32, val_split=0.2, random_state=None, num_workers=4):
    """Create stratified train and validation DataLoaders.

    Args:
        df: DataFrame with ``image_path``, ``plant`` and ``disease`` columns.
        batch_size: Samples per batch for both loaders.
        val_split: Fraction of ``df`` held out for validation.
        random_state: Seed for the train/val split. Pass an int for a
            reproducible split (default ``None`` keeps it random).
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, classes)``, where ``classes`` is the
        label order used by both loaders.
    """
    from sklearn.model_selection import train_test_split
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    # Stratify on the label so every class appears in both splits.
    train_df, val_df = train_test_split(
        df, test_size=val_split, stratify=df['disease'], random_state=random_state
    )

    # ImageNet channel mean/std.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # NOTE(review): positional RandomResizedCrop(224, 224) means height/width in
    # albumentations 1.x; newer releases expect size=(224, 224). Confirm the
    # pinned version.
    train_transform = A.Compose([
        A.RandomResizedCrop(224, 224),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ])

    val_transform = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ])

    train_ds = PlantDiseaseDataset(train_df, transform=train_transform)
    val_ds = PlantDiseaseDataset(val_df, transform=val_transform)
    # Each dataset derives its label order from its own rows. Pin the
    # validation set to the training mapping so label indices always agree.
    val_ds.classes = train_ds.classes
    val_ds.class_to_idx = train_ds.class_to_idx

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        # A final batch of exactly one sample would make the BatchNorm1d layer
        # in PlantDiseaseModel's classifier raise in training mode.
        drop_last=len(train_ds) % batch_size == 1,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, train_ds.classes

Model Development

1. Model Architecture

import torch.nn as nn
import timm

class PlantDiseaseModel(nn.Module):
    """timm feature extractor topped with a two-layer MLP classification head.

    Args:
        num_classes: Number of output logits.
        model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Whether to load pretrained backbone weights.
    """

    def __init__(self, num_classes, model_name='efficientnet_b0', pretrained=True):
        super().__init__()
        # num_classes=0 drops timm's own head so the backbone returns features
        # of width ``num_features`` instead of logits.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        hidden = 512
        head_layers = [
            nn.Dropout(0.2),
            nn.Linear(self.backbone.num_features, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map an image batch ``x`` to class logits."""
        return self.classifier(self.backbone(x))

2. Training Loop

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

class PlantDiseaseClassifier(pl.LightningModule):
    """Lightning module that trains ``PlantDiseaseModel`` with cross-entropy.

    Args:
        num_classes: Number of disease classes.
        learning_rate: Initial AdamW learning rate.
        t_max: Cosine-annealing period in epochs. When ``None`` (default), the
            trainer's ``max_epochs`` is used so the learning rate decays once
            over the whole run. Falls back to ``DEFAULT_T_MAX`` if that is
            unavailable.
    """

    # Fallback cosine period (the previously hard-coded value).
    DEFAULT_T_MAX = 10

    def __init__(self, num_classes, learning_rate=1e-3, t_max=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = PlantDiseaseModel(num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute ``(loss, logits, labels)`` for one ``{'image', 'label'}`` batch."""
        labels = batch['label']
        logits = self(batch['image'])
        return self.criterion(logits, labels), logits, labels

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self._shared_step(batch)
        acc = (torch.argmax(logits, dim=1) == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def _cosine_period(self):
        """Resolve the CosineAnnealingLR period (in epochs)."""
        t_max = self.hparams.get('t_max')
        if t_max is not None:
            return t_max
        try:
            max_epochs = self.trainer.max_epochs
        except (AttributeError, RuntimeError):
            # Not attached to a Trainer.
            max_epochs = None
        if max_epochs is not None and max_epochs > 0:
            return max_epochs
        return self.DEFAULT_T_MAX

    def configure_optimizers(self):
        """AdamW with per-epoch cosine annealing down to ``eta_min=1e-6``."""
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)
        # A period shorter than the run makes the cosine schedule climb back up
        # after T_max epochs; tying it to max_epochs gives one smooth decay.
        # ('monitor' is only needed for metric-driven schedulers such as
        # ReduceLROnPlateau, so it is omitted.)
        scheduler = {
            'scheduler': CosineAnnealingLR(
                optimizer,
                T_max=self._cosine_period(),
                eta_min=1e-6,
            ),
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [scheduler]

def train_model(train_loader, val_loader, num_classes, max_epochs=20):
    """Train a ``PlantDiseaseClassifier`` and return ``(model, trainer)``.

    Args:
        train_loader: DataLoader yielding ``{'image', 'label', ...}`` batches.
        val_loader: Validation DataLoader with the same batch format.
        num_classes: Number of disease classes.
        max_epochs: Upper bound on training epochs.

    The best checkpoint by ``val_acc`` is written to ``checkpoints/``.
    Training stops early after 5 epochs without ``val_loss`` improvement.
    """
    from pytorch_lightning.callbacks import TQDMProgressBar
    from pytorch_lightning.loggers import TensorBoardLogger

    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        dirpath='checkpoints/',
        filename='plant-disease-{epoch:02d}-{val_acc:.3f}',
    )

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min',
    )

    # Refresh every 10 batches; replaces the removed `progress_bar_refresh_rate`
    # Trainer argument.
    progress_bar = TQDMProgressBar(refresh_rate=10)

    model = PlantDiseaseClassifier(num_classes=num_classes)
    # `gpus=` was removed in Lightning 2.0; choose the hardware via
    # `accelerator`/`devices` instead.
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback, progress_bar],
        logger=TensorBoardLogger('lightning_logs/'),
    )

    trainer.fit(model, train_loader, val_loader)

    return model, trainer

Knowledge Graph Integration

1. GraphDB Setup

from SPARQLWrapper import SPARQLWrapper, JSON
import pandas as pd

class KnowledgeGraphManager:
    """Run SPARQL queries against the plant-disease GraphDB repository.

    Args:
        endpoint: SPARQL endpoint URL. When omitted, the ``GRAPHDB_ENDPOINT``
            environment variable is used if set (docker-compose sets it for
            the api service); otherwise ``DEFAULT_ENDPOINT``.
    """

    DEFAULT_ENDPOINT = "http://localhost:7200/repositories/plant-disease"
    # Characters SPARQL forbids inside an IRIREF (<...>). Code points
    # 0x00-0x20 are also forbidden and are checked separately.
    _IRI_FORBIDDEN = frozenset('<>"{}|^`\\')

    def __init__(self, endpoint: "str | None" = None):
        import os

        if endpoint is None:
            endpoint = os.environ.get("GRAPHDB_ENDPOINT", self.DEFAULT_ENDPOINT)
        self.sparql = SPARQLWrapper(endpoint)
        self.sparql.setReturnFormat(JSON)

    @classmethod
    def _check_iri(cls, iri: str) -> str:
        """Return ``iri`` unchanged if it can be embedded as ``<iri>`` in SPARQL.

        Raises:
            ValueError: if ``iri`` is empty or contains characters that are
                illegal in a SPARQL IRI. Such input could otherwise break or
                inject into the query.
        """
        if not iri or any(ch in cls._IRI_FORBIDDEN or ord(ch) <= 0x20 for ch in iri):
            raise ValueError(f"Invalid IRI for SPARQL query: {iri!r}")
        return iri

    def query(self, query: str) -> pd.DataFrame:
        """Execute a SPARQL SELECT query and return the bindings as a DataFrame.

        There is one row per binding and one column per projected variable, in
        projection order. Unbound variables appear as NaN. The columns are
        present even when there are no results.
        """
        self.sparql.setQuery(query)
        results = self.sparql.query().convert()

        variables = results["head"]["vars"]
        rows = [
            {var: binding[var]["value"] for var in variables if var in binding}
            for binding in results["results"]["bindings"]
        ]
        return pd.DataFrame(rows, columns=variables)

    def get_disease_info(self, disease_uri: str) -> dict:
        """Return ``{property label: value}`` for each labelled predicate of a disease.

        Returns an empty dict when the graph has nothing for ``disease_uri``.

        Raises:
            ValueError: if ``disease_uri`` is not a valid SPARQL IRI.
        """
        self._check_iri(disease_uri)
        query = f"""
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX plant: <http://example.org/plant-ontology#>
        
        SELECT ?property ?value
        WHERE {{
            <{disease_uri}> ?p ?o .
            ?p rdfs:label ?property .
            BIND(IF(isLiteral(?o), ?o, STR(?o)) AS ?value)
        }}
        """
        results = self.query(query)
        # An empty result may lack the columns entirely; avoid KeyError.
        if results.empty:
            return {}
        return dict(zip(results['property'], results['value']))

API Development

1. FastAPI Application

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import torch
from PIL import Image
import io
import numpy as np

# Initialize FastAPI app
app = FastAPI(title="Plant Disease Diagnosis API")

# CORS middleware: any origin may call the API, but without credentials.
# Browsers reject a literal "*" origin on credentialed requests. Starlette works
# around that by reflecting the caller's Origin, which would let any website
# make credentialed cross-site calls. No endpoint shown here uses cookies or
# auth headers, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared inference state. `model` stays None until startup code loads one.
model = None  # Load your trained model here
kg_manager = KnowledgeGraphManager()


class DiagnosisResult(BaseModel):
    """Response schema of ``POST /diagnose``."""

    disease: str  # predicted class name
    confidence: float  # softmax probability of the predicted class
    treatment: str
    prevention: str
    scientific_info: dict  # property/value pairs returned by the knowledge graph

@app.on_event("startup")
async def startup_event():
    """Prepare models and services when the API process starts.

    This is currently a placeholder. Nothing is loaded, so the module-level
    ``model`` remains ``None`` until loading code is added here.
    """
    global model  # loading code added below must rebind the module-level name
    # TODO: load the trained classifier, e.g.
    #     model = load_model(...)

@app.post("/diagnose", response_model=DiagnosisResult)
async def diagnose_plant(
    image: UploadFile = File(...),
    symptoms: Optional[str] = None
):
    """Classify an uploaded plant image and attach knowledge-graph details.

    Args:
        image: Uploaded image file; anything PIL can decode.
        symptoms: Optional comma-separated symptoms. Accepted but not used yet.
            NOTE(review): declared like this, FastAPI reads it from the query
            string, while the Streamlit client sends form data. Confirm which
            is intended.

    Returns:
        A dict matching ``DiagnosisResult``.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when the upload cannot
            be decoded as an image, 500 for any other failure.
    """
    from urllib.parse import quote

    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        contents = await image.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')
    except OSError as e:
        # PIL reports undecodable or truncated data as OSError (including its
        # UnidentifiedImageError subclass). That is a client error, not a 500.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    try:
        # Inference needs eval mode: dropout off, and BatchNorm layers (the
        # PlantDiseaseModel head has one) then accept a single-sample batch.
        if getattr(model, 'training', False):
            model.eval()

        # Convert to tensor and add batch dimension.
        img_tensor = preprocess_image(img).unsqueeze(0)
        # Run on whichever device holds the model's weights.
        first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
        if first_param is not None:
            img_tensor = img_tensor.to(first_param.device)

        with torch.no_grad():
            probs = torch.softmax(model(img_tensor), dim=1)
            confidence, pred_idx = torch.max(probs, dim=1)

        # NOTE(review): assumes the loaded model exposes a `classes` list.
        class_name = model.classes[pred_idx.item()]

        # Percent-encode the class name. Names with spaces or other
        # IRI-illegal characters would otherwise break the SPARQL query.
        # NOTE(review): confirm the knowledge graph mints disease IRIs the same way.
        disease_info = kg_manager.get_disease_info(
            f"http://example.org/disease/{quote(class_name, safe='')}"
        )

        return {
            'disease': class_name,
            'confidence': confidence.item(),
            'treatment': disease_info.get('treatment', 'No treatment information available'),
            'prevention': disease_info.get('prevention', 'No prevention information available'),
            'scientific_info': disease_info
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

def preprocess_image(image, size=(224, 224), resize=(256, 256)):
    """Convert a PIL image into a normalized ``(3, H, W)`` float tensor.

    The steps are: resize to ``resize``, center-crop to ``size``, then apply
    ImageNet normalization. This is the same resize-then-center-crop as the
    validation transform in ``get_data_loaders``, so inference inputs match
    what the model was evaluated on.

    Args:
        image: RGB ``PIL.Image``.
        size: Final ``(height, width)`` after center cropping.
        resize: ``(height, width)`` to resize to before cropping.
    """
    # `transforms` was never imported at module level; import it explicitly.
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)

if __name__ == "__main__":
    # Listen on all interfaces, port 8000 (the port the Streamlit client calls).
    uvicorn.run(app, port=8000, host="0.0.0.0")

Frontend Development

1. Streamlit Interface

import streamlit as st
import requests
import io
from PIL import Image
import numpy as np

# Page config
st.set_page_config(page_title="Plant Disease Diagnosis", page_icon="🌱", layout="wide")

# Custom CSS: light page background and green action buttons.
_CUSTOM_CSS = """
<style>
    .main {
        background-color: #f8f9fa;
    }
    .stButton>button {
        background-color: #28a745;
        color: white;
        font-weight: bold;
    }
    .stButton>button:hover {
        background-color: #218838;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Header
st.title("🌱 Plant Disease Diagnosis")
st.markdown("Upload an image of a plant to diagnose potential diseases.")

# Main input: the plant photo to diagnose.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Sidebar: optional symptoms sent to the API along with the image.
_SYMPTOM_CHOICES = ["Yellowing leaves", "Brown spots", "White powder", "Wilting", "Stunted growth"]
st.sidebar.header("Additional Symptoms")
symptoms = st.sidebar.multiselect("Select any additional symptoms you observe:", _SYMPTOM_CHOICES)

import os

# Base URL of the diagnosis API. docker-compose sets API_URL for the frontend
# container; local runs fall back to the development server.
API_URL = os.environ.get("API_URL", "http://localhost:8000").rstrip("/")
# Seconds to wait for the API, so a hung backend cannot block the page forever.
REQUEST_TIMEOUT_S = 60

if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Make prediction when button is clicked
    if st.button('Diagnose Plant'):
        with st.spinner('Analyzing...'):
            try:
                response = requests.post(
                    f"{API_URL}/diagnose",
                    files={"image": (uploaded_file.name, uploaded_file.getvalue())},
                    # multiselect returns a list; joining an empty list gives "".
                    data={"symptoms": ",".join(symptoms)},
                    timeout=REQUEST_TIMEOUT_S,
                )

                if response.status_code == 200:
                    result = response.json()

                    # Display results
                    st.success("Analysis Complete!")

                    col1, col2 = st.columns(2)

                    with col1:
                        st.subheader("Diagnosis")
                        st.metric("Disease", result['disease'])
                        st.metric("Confidence", f"{result['confidence']*100:.1f}%")

                        st.subheader("Treatment")
                        st.info(result['treatment'])

                    with col2:
                        st.subheader("Prevention")
                        st.warning(result['prevention'])

                        st.subheader("Scientific Information")
                        # Treatment/prevention are already shown above.
                        for key, value in result['scientific_info'].items():
                            if key not in ['treatment', 'prevention']:
                                st.text(f"{key}: {value}")
                else:
                    st.error(f"Error: {response.text}")

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

# Footer: a divider followed by a short "About" section.
_FOOTER_SECTIONS = (
    "---",
    "### About",
    """
This application uses deep learning and knowledge graphs to diagnose plant diseases.
The system combines computer vision with semantic reasoning for accurate diagnosis.
""",
)
for _section in _FOOTER_SECTIONS:
    st.markdown(_section)

if __name__ == "__main__":
    import subprocess
    import sys

    from streamlit import runtime

    # `streamlit run` also executes this script with __name__ == "__main__" and
    # re-runs it on every interaction. Without this check, each rerun would
    # spawn another Streamlit server. Only launch one when the file is started
    # as a plain `python <file>`.
    if not runtime.exists():
        subprocess.run([sys.executable, "-m", "streamlit", "run", __file__], check=False)

Deployment

1. Docker Compose

# Docker Compose stack: FastAPI backend (api), Streamlit frontend and a GraphDB
# triple store. Services address each other by service name (e.g. graphdb, api).
# NOTE(review): recent Compose releases treat the top-level `version` key as
# obsolete; confirm whether it is still needed for the targeted Compose version.
version: '3.8'

services:
  # API Service
  api:
    build: 
      context: .
      dockerfile: Dockerfile.api
    ports:
      # host:container
      - "8000:8000"
    environment:
      # SPARQL endpoint of the graphdb service, as seen from inside the network.
      # NOTE(review): confirm the API code reads GRAPHDB_ENDPOINT; a hard-coded
      # localhost endpoint would not reach the graphdb container.
      - GRAPHDB_ENDPOINT=http://graphdb:7200/repositories/plant-disease
    depends_on:
      # NOTE(review): short-form depends_on presumably only orders start-up and
      # does not wait for GraphDB to be ready; consider a healthcheck.
      - graphdb
    restart: unless-stopped

  # Frontend Service
  frontend:
    build:
      context: .
      dockerfile: Dockerfile.frontend
    ports:
      # Streamlit UI
      - "8501:8501"
    environment:
      # Base URL of the api service as seen from the frontend container.
      # NOTE(review): confirm the frontend reads API_URL instead of localhost.
      - API_URL=http://api:8000
    depends_on:
      - api
    restart: unless-stopped

  # GraphDB
  graphdb:
    image: ontotext/graphdb:10.1.0
    ports:
      - "7200:7200"
    volumes:
      # Named volume so repository data survives container re-creation.
      - graphdb-data:/opt/graphdb/home/data
    environment:
      # NOTE(review): confirm this image version honours GDB_HEAP_SIZE (some
      # releases take JVM heap flags via GDB_JAVA_OPTS instead).
      - GDB_HEAP_SIZE=2g
    restart: unless-stopped

# Named volumes used above.
volumes:
  graphdb-data:

Monitoring and Maintenance

1. Logging and Metrics

import logging
from prometheus_client import start_http_server, Counter, Histogram
import time

# Configure the root logger: timestamp, logger name, level, message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Prometheus metrics, labelled per request path.
# NOTE(review): `start_http_server` is imported but not called in this section,
# so nothing here exposes these metrics for scraping; confirm they are served
# elsewhere.
REQUESTS = Counter('api_requests_total', 'Total API requests', labelnames=('endpoint', 'method'))
LATENCY = Histogram('api_latency_seconds', 'API latency in seconds', labelnames=('endpoint',))
ERRORS = Counter('api_errors_total', 'Total API errors', labelnames=('endpoint', 'error_type'))

# `Request` is used in the signature annotation, which is evaluated when the
# function is defined, so it must be imported first.
from fastapi import Request


@app.middleware("http")
async def monitor_requests(request: Request, call_next):
    """Count, time, and record failures of every HTTP request as Prometheus metrics.

    A request counts as an error when the downstream app raises or returns a
    5xx response (e.g. an HTTPException(500) already turned into a response).
    Latency is recorded for every request, successful or not.
    """
    # NOTE(review): raw paths make label cardinality unbounded (e.g. scanners
    # probing random URLs); consider labelling by route template instead.
    endpoint = request.url.path
    method = request.method
    REQUESTS.labels(endpoint=endpoint, method=method).inc()

    # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
    start = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception as e:
        ERRORS.labels(endpoint=endpoint, error_type=type(e).__name__).inc()
        logger.error("Error processing request: %s", e, exc_info=True)
        raise
    finally:
        LATENCY.labels(endpoint=endpoint).observe(time.perf_counter() - start)

    if response.status_code >= 500:
        ERRORS.labels(endpoint=endpoint, error_type=f"HTTP{response.status_code}").inc()
    return response

Next Steps

  1. Dataset Expansion: Add more plant species and diseases to the training data
  2. Model Optimization: Prune and quantize models for edge deployment
  3. Active Learning: Continuously improve the model with user feedback
  4. Mobile App: Develop native mobile applications for field use
  5. Multi-language Support: Add support for multiple languages in the interface

References