import base64
import csv
import io
import zlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import pandas as pd
from sqlalchemy import func, and_, or_, desc, case
from sqlalchemy.orm import joinedload

from src.models import DatabaseContextManager
from src.models.models import (
    Student,
    Course,
    Enrollment,
    Attendance,
    AssignmentSubmission,
    Assignment,
    TeachingSession,
    CourseProgress,
    AcademicSession
)
from src.utils import custom_response

class EnrollmentReportsManager:
    """
    Build enrollment analytics reports and CSV/Excel exports.

    The public entry points (``get_enrollment_analytics`` and
    ``export_enrollment_data``) catch exceptions and return
    ``custom_response(success=False, ...)`` instead of propagating them.
    """

    # Column order shared by the CSV and Excel exports of per-course performance.
    _PERFORMANCE_EXPORT_FIELDS = [
        'course_code', 'course_name', 'department', 'total_students',
        'attendance_rate', 'assignment_completion_rate', 'average_grade',
        'performance_score', 'active_students', 'at_risk_students', 'excellent_students',
    ]

    def __init__(self):
        pass

    def get_enrollment_analytics(self, period: str = '30d', start_date: str = None,
                                end_date: str = None, department: str = None,
                                course_id: str = None) -> Dict:
        """
        Get comprehensive enrollment analytics with student distribution and performance.

        Args:
            period: '7d', '30d', '90d' or '1y'. It is ignored when both
                ``start_date`` and ``end_date`` are supplied. Unknown values
                mean 30 days.
            start_date: ISO-format start of a custom window.
            end_date: ISO-format end of a custom window.
            department: Optional ``Course.department`` filter, applied to the
                distribution and performance sections.
            course_id: Optional ``Enrollment.course_id`` filter, applied to the
                distribution and performance sections.

        Returns:
            A ``custom_response`` payload containing:
            - overall stats,
            - the enrollment distribution,
            - per-course performance,
            - weekly trends,
            - the resolved date window and the applied filters.
            Invalid custom dates yield status 400; any other failure yields 500.
        """
        # Validate user-supplied dates before touching the database, so bad
        # input is reported as a client error rather than a server error.
        try:
            date_range = self._calculate_date_range(period, start_date, end_date)
        except (ValueError, TypeError) as e:
            return custom_response(
                success=False,
                data={'error': f"Invalid date range: {str(e)}"},
                status_code=400
            )

        try:
            with DatabaseContextManager() as ctx:
                enrollment_distribution = self._get_enrollment_distribution(ctx, date_range, department, course_id)
                student_performance = self._get_student_performance_by_course(ctx, date_range, department, course_id)
                # Trends are computed over all enrollments; the department and
                # course filters are not applied to them.
                enrollment_trends = self._get_enrollment_trends(ctx, date_range)
                overall_stats = self._calculate_enrollment_overall_stats(ctx, enrollment_distribution, student_performance)

                return custom_response(
                    success=True,
                    data={
                        'overall_stats': overall_stats,
                        'enrollment_distribution': enrollment_distribution,
                        'student_performance': student_performance,
                        'enrollment_trends': enrollment_trends,
                        'date_range': {
                            'start': date_range['start'].isoformat(),
                            'end': date_range['end'].isoformat(),
                            'period': period
                        },
                        'filters': {
                            'department': department,
                            'course_id': course_id
                        },
                        'message': "Enrollment analytics retrieved successfully"
                    }
                )

        except Exception as e:
            return custom_response(
                success=False,
                data={'error': f"Error retrieving enrollment analytics: {str(e)}"},
                status_code=500
            )

    def _build_enrollment_query(self, ctx, department: str = None, course_id: str = None):
        """
        Build the Enrollment query shared by the distribution and performance reports.

        The query eagerly loads ``Enrollment.student`` and ``Enrollment.course``.
        It is optionally restricted to one course and/or one ``Course.department``.
        """
        query = ctx.session.query(Enrollment).options(
            joinedload(Enrollment.student),
            joinedload(Enrollment.course)
        )
        if course_id:
            query = query.filter(Enrollment.course_id == course_id)
        if department:
            query = query.join(Course).filter(Course.department == department)
        return query

    def _get_enrollment_distribution(self, ctx, date_range: Dict, department: str = None,
                                   course_id: str = None) -> Dict:
        """
        Get enrollment distribution across courses, departments, gender and status.

        NOTE(review): ``date_range`` is accepted but not applied here. The
        distribution covers every enrollment matching the department/course
        filters, regardless of the requested period.

        Returns:
            A dict with the following keys:
            - ``by_course``: sorted by ``enrollment_count``, descending.
            - ``by_department``
            - ``by_gender``: only buckets with a non-zero count.
            - ``by_status``: only buckets with a non-zero count.
            - ``total_enrollments``

        Raises:
            Exception: wrapping any underlying error.
        """
        try:
            enrollments = self._build_enrollment_query(ctx, department, course_id).all()

            course_distribution = {}
            department_distribution = {}
            gender_distribution = {'Male': 0, 'Female': 0, 'Other': 0}
            status_distribution = {'Active': 0, 'Completed': 0, 'Dropped': 0, 'Suspended': 0}
            # Any enrollment_status outside this map (including None) is
            # reported in the 'Suspended' bucket.
            status_labels = {'active': 'Active', 'completed': 'Completed', 'dropped': 'Dropped'}

            for enrollment in enrollments:
                course = enrollment.course
                student = enrollment.student

                if course.code not in course_distribution:
                    course_distribution[course.code] = {
                        'course_name': course.title,
                        'enrollment_count': 0,
                        'department': course.department,
                        'credits': course.credits or 0,
                        'max_students': course.max_students or 0
                    }
                course_distribution[course.code]['enrollment_count'] += 1

                department_distribution[course.department] = department_distribution.get(course.department, 0) + 1

                # Students with an empty/None gender are not counted in any bucket.
                if student.gender:
                    gender = student.gender.lower()
                    if gender in ('male', 'm'):
                        gender_distribution['Male'] += 1
                    elif gender in ('female', 'f'):
                        gender_distribution['Female'] += 1
                    else:
                        gender_distribution['Other'] += 1

                status_distribution[status_labels.get(enrollment.enrollment_status, 'Suspended')] += 1

            course_distribution_list = [
                {
                    'course_code': code,
                    'course_name': data['course_name'],
                    'department': data['department'],
                    'enrollment_count': data['enrollment_count'],
                    'credits': data['credits'],
                    'max_students': data['max_students'],
                    # Fill ratio against capacity; 0 when capacity is unknown/zero.
                    'enrollment_rate': round((data['enrollment_count'] / data['max_students'] * 100) if data['max_students'] > 0 else 0, 2)
                }
                for code, data in course_distribution.items()
            ]
            course_distribution_list.sort(key=lambda x: x['enrollment_count'], reverse=True)

            return {
                'by_course': course_distribution_list,
                'by_department': [
                    {'department': dept, 'count': count}
                    for dept, count in department_distribution.items()
                ],
                'by_gender': [
                    {'gender': gender, 'count': count}
                    for gender, count in gender_distribution.items() if count > 0
                ],
                'by_status': [
                    {'status': status, 'count': count}
                    for status, count in status_distribution.items() if count > 0
                ],
                'total_enrollments': len(enrollments)
            }

        except Exception as e:
            raise Exception(f"Error calculating enrollment distribution: {str(e)}") from e

    def _get_student_performance_by_course(self, ctx, date_range: Dict, department: str = None,
                                         course_id: str = None) -> List[Dict]:
        """
        Get per-course student performance metrics, sorted by ``performance_score`` (descending).

        For each enrollment, this method computes the student's attendance
        rate, assignment completion rate and mock average grade.

        Each student is classified into exactly one category:
        - ``excellent``: grade >= 4.0, attendance >= 90 and completion >= 90.
        - ``active``: otherwise, if grade >= 2.0 and attendance >= 70.
        - ``at_risk``: everything else.

        Per-course averages are rounded to 2 decimals. ``performance_score`` is
        ``0.3 * attendance + 0.4 * completion + 0.3 * (grade / 5 * 100)``,
        computed from those rounded averages.

        NOTE(review): this issues several queries per enrollment (N+1 pattern).

        Raises:
            Exception: wrapping any underlying error.
        """
        try:
            enrollments = self._build_enrollment_query(ctx, department, course_id).all()

            course_performance = {}

            for enrollment in enrollments:
                course = enrollment.course
                student = enrollment.student

                if course.id not in course_performance:
                    course_performance[course.id] = {
                        'course_code': course.code,
                        'course_name': course.title,
                        'department': course.department,
                        'total_students': 0,
                        'attendance_rate': 0.0,
                        'assignment_completion_rate': 0.0,
                        'average_grade': 0.0,
                        'active_students': 0,
                        'at_risk_students': 0,
                        'excellent_students': 0
                    }

                course_data = course_performance[course.id]
                course_data['total_students'] += 1

                attendance_rate = self._calculate_student_attendance_rate(ctx, student.id, course.id, date_range)
                completion_rate = self._calculate_assignment_completion_rate(ctx, student.id, course.id, date_range)
                average_grade = self._calculate_average_grade(ctx, student.id, course.id)

                if average_grade >= 4.0 and attendance_rate >= 90 and completion_rate >= 90:
                    course_data['excellent_students'] += 1
                elif average_grade >= 2.0 and attendance_rate >= 70:
                    course_data['active_students'] += 1
                else:
                    course_data['at_risk_students'] += 1

                # Accumulate sums here; they are turned into averages below.
                course_data['attendance_rate'] += attendance_rate
                course_data['assignment_completion_rate'] += completion_rate
                course_data['average_grade'] += average_grade

            for data in course_performance.values():
                total = data['total_students']
                if total > 0:
                    data['attendance_rate'] = round(data['attendance_rate'] / total, 2)
                    data['assignment_completion_rate'] = round(data['assignment_completion_rate'] / total, 2)
                    data['average_grade'] = round(data['average_grade'] / total, 2)

                    data['performance_score'] = round(
                        (data['attendance_rate'] * 0.3) +
                        (data['assignment_completion_rate'] * 0.4) +
                        ((data['average_grade'] / 5.0) * 100 * 0.3), 2
                    )

            performance_list = list(course_performance.values())
            performance_list.sort(key=lambda x: x['performance_score'], reverse=True)

            return performance_list

        except Exception as e:
            raise Exception(f"Error calculating student performance by course: {str(e)}") from e

    def _calculate_student_attendance_rate(self, ctx, student_id: str, course_id: str, date_range: Dict) -> float:
        """
        Return the student's attendance percentage (0-100, 2 decimals) for a course.

        The rate is the number of 'present' Attendance rows for the student,
        divided by the number of the course's TeachingSessions, both limited
        to sessions whose ``created_at`` falls within ``date_range``
        (inclusive). It is 0.0 when there are no sessions, or on any error
        (this is a deliberate best-effort fallback).
        """
        try:
            total_sessions = ctx.session.query(TeachingSession).filter(
                and_(
                    TeachingSession.course_id == course_id,
                    TeachingSession.created_at >= date_range['start'],
                    TeachingSession.created_at <= date_range['end']
                )
            ).count()

            if total_sessions == 0:
                return 0.0

            attended_sessions = ctx.session.query(Attendance).join(TeachingSession).filter(
                and_(
                    Attendance.student_id == student_id,
                    TeachingSession.course_id == course_id,
                    TeachingSession.created_at >= date_range['start'],
                    TeachingSession.created_at <= date_range['end'],
                    Attendance.status == 'present'
                )
            ).count()

            return round((attended_sessions / total_sessions) * 100, 2)

        except Exception:
            return 0.0

    def _calculate_assignment_completion_rate(self, ctx, student_id: str, course_id: str, date_range: Dict) -> float:
        """
        Return the student's assignment completion percentage (0-100, 2 decimals) for a course.

        The rate is the number of the student's submissions with status
        'submitted', divided by the number of the course's assignments.
        ``date_range`` is not applied. The result is 0.0 when the course has
        no assignments, or on any error (best-effort fallback).
        """
        try:
            total_assignments = ctx.session.query(Assignment).filter(
                Assignment.course_id == course_id
            ).count()

            if total_assignments == 0:
                return 0.0

            completed_assignments = ctx.session.query(AssignmentSubmission).join(Assignment).filter(
                and_(
                    AssignmentSubmission.student_id == student_id,
                    Assignment.course_id == course_id,
                    AssignmentSubmission.status == 'submitted'
                )
            ).count()

            return round((completed_assignments / total_assignments) * 100, 2)

        except Exception:
            return 0.0

    def _calculate_average_grade(self, ctx, student_id: str, course_id: str) -> float:
        """
        Return a mock average grade in [3.0, 4.99] for a student in a course.

        This is a placeholder until real grade records are wired in. The value
        is derived from a CRC32 of the concatenated ids rather than the
        built-in ``hash()``. String hashing is salted per interpreter process
        (PYTHONHASHSEED), so ``hash()`` would give every worker or restart a
        different grade for the same student and course.
        """
        try:
            seed = zlib.crc32(f"{student_id}{course_id}".encode("utf-8"))
            base_grade = 3.0 + (seed % 200) / 100
            return round(base_grade, 2)

        except Exception:
            return 3.0

    def _get_enrollment_trends(self, ctx, date_range: Dict) -> List[Dict]:
        """
        Get weekly enrollment trends starting at ``date_range['start']``.

        Each window is the half-open interval ``[week_start, week_start + 7d)``.
        The next window starts exactly where the previous one ends, so an
        inclusive upper bound would count a record on that boundary in two
        weeks. A new window is started while ``week_start <= date_range['end']``,
        so the last window may extend past the end of the range.

        Each entry contains:
        - ``week``: ISO start of the window.
        - ``new_enrollments``: enrollments dated inside the window.
        - ``active_students``: distinct students with an Attendance row for a
          session created inside the window.
        - ``total_enrollments``: enrollments dated before the window's end.

        Returns an empty list on any error (best-effort).
        """
        try:
            trends = []
            week_length = timedelta(days=7)
            week_start = date_range['start']

            while week_start <= date_range['end']:
                week_end = week_start + week_length

                weekly_enrollments = ctx.session.query(Enrollment).filter(
                    and_(
                        Enrollment.enrollment_date >= week_start,
                        Enrollment.enrollment_date < week_end
                    )
                ).count()

                active_students = ctx.session.query(Student).join(Attendance).join(TeachingSession).filter(
                    and_(
                        TeachingSession.created_at >= week_start,
                        TeachingSession.created_at < week_end
                    )
                ).distinct().count()

                total_enrollments = ctx.session.query(Enrollment).filter(
                    Enrollment.enrollment_date < week_end
                ).count()

                trends.append({
                    'week': week_start.isoformat(),
                    'new_enrollments': weekly_enrollments,
                    'active_students': active_students,
                    'total_enrollments': total_enrollments
                })

                week_start = week_end

            return trends

        except Exception:
            return []

    def _calculate_enrollment_overall_stats(self, ctx, enrollment_distribution: Dict,
                                          student_performance: List[Dict]) -> Dict:
        """
        Summarize the distribution and performance results into headline numbers.

        ``student_performance`` is expected to be sorted by
        ``performance_score`` (descending); its first entry is reported as the
        top course. ``ctx`` is unused. Returns an all-zero summary if the
        inputs are missing expected keys or any other error occurs.

        NOTE(review): ``enrollment_growth_rate`` and ``completion_rate`` are
        hard-coded placeholder values, not computed from data.
        """
        try:
            total_enrollments = enrollment_distribution['total_enrollments']
            departments = len(enrollment_distribution['by_department'])
            courses = len(enrollment_distribution['by_course'])

            avg_performance = 0
            if student_performance:
                total_performance = sum(course['performance_score'] for course in student_performance)
                avg_performance = round(total_performance / len(student_performance), 2)

            gender_data = enrollment_distribution['by_gender']
            total_gender = sum(g['count'] for g in gender_data)
            gender_percentages = {
                info['gender']: round((info['count'] / total_gender) * 100, 1) if total_gender > 0 else 0
                for info in gender_data
            }

            top_course = student_performance[0] if student_performance else None

            return {
                'total_enrollments': total_enrollments,
                'departments_count': departments,
                'courses_count': courses,
                'average_performance': avg_performance,
                'male_percentage': gender_percentages.get('Male', 0),
                'female_percentage': gender_percentages.get('Female', 0),
                'top_performing_course': top_course['course_name'] if top_course else 'N/A',
                'top_course_performance': top_course['performance_score'] if top_course else 0,
                'enrollment_growth_rate': 12.5,  # Mock growth rate
                'completion_rate': 87.3  # Mock completion rate
            }

        except Exception:
            return {
                'total_enrollments': 0,
                'departments_count': 0,
                'courses_count': 0,
                'average_performance': 0,
                'male_percentage': 0,
                'female_percentage': 0,
                'top_performing_course': 'N/A',
                'top_course_performance': 0,
                'enrollment_growth_rate': 0,
                'completion_rate': 0
            }

    def _calculate_date_range(self, period: str, start_date: str = None, end_date: str = None) -> Dict:
        """
        Resolve the reporting window as ``{'start': datetime, 'end': datetime}``.

        If both ``start_date`` and ``end_date`` are given, they are parsed with
        ``datetime.fromisoformat`` and take precedence over ``period``. If only
        one of them is given, it is ignored.

        Otherwise the window ends at ``datetime.now()`` and starts 7, 30, 90 or
        365 days earlier for '7d', '30d', '90d' or '1y' respectively. Unknown
        periods fall back to 30 days.

        Raises:
            ValueError: if a custom date is not valid ISO format, if one custom
                date is timezone-aware and the other is naive, or if the custom
                start is after the custom end.
        """
        if start_date and end_date:
            start = datetime.fromisoformat(start_date)
            end = datetime.fromisoformat(end_date)
            try:
                inverted = start > end
            except TypeError as e:
                raise ValueError("start_date and end_date must both be naive or both timezone-aware") from e
            if inverted:
                raise ValueError("start_date must not be after end_date")
            return {
                'start': start,
                'end': end
            }

        period_days = {'7d': 7, '30d': 30, '90d': 90, '1y': 365}
        end = datetime.now()
        start = end - timedelta(days=period_days.get(period, 30))

        return {
            'start': start,
            'end': end
        }

    def export_enrollment_data(self, period: str = '30d', format_type: str = 'csv') -> Dict:
        """
        Export per-course enrollment performance as CSV or Excel.

        ``format_type == 'csv'`` produces CSV; any other value produces Excel
        (.xlsx). Errors are returned as a ``custom_response`` with status 500.
        """
        try:
            with DatabaseContextManager() as ctx:
                date_range = self._calculate_date_range(period)

                enrollment_distribution = self._get_enrollment_distribution(ctx, date_range)
                student_performance = self._get_student_performance_by_course(ctx, date_range)

                if format_type == 'csv':
                    return self._export_to_csv(enrollment_distribution, student_performance, period)
                return self._export_to_excel(enrollment_distribution, student_performance, period)

        except Exception as e:
            return custom_response(
                success=False,
                data={'error': f"Error exporting enrollment data: {str(e)}"},
                status_code=500
            )

    def _export_to_csv(self, enrollment_distribution: Dict, student_performance: List[Dict], period: str) -> Dict:
        """
        Render per-course performance rows as CSV text in a ``custom_response``.

        The columns are ``_PERFORMANCE_EXPORT_FIELDS``. ``enrollment_distribution``
        is currently unused.
        """
        try:
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=self._PERFORMANCE_EXPORT_FIELDS)
            writer.writeheader()
            writer.writerows(student_performance)

            csv_content = output.getvalue()
            output.close()

            filename = f"enrollment_performance_{period}_{datetime.now().strftime('%Y%m%d')}.csv"

            return custom_response(
                success=True,
                data={
                    'content': csv_content,
                    'filename': filename,
                    'content_type': 'text/csv',
                    'message': "Enrollment data exported successfully"
                }
            )

        except Exception as e:
            return custom_response(
                success=False,
                data={'error': f"Error creating CSV export: {str(e)}"},
                status_code=500
            )

    def _export_to_excel(self, enrollment_distribution: Dict, student_performance: List[Dict], period: str) -> Dict:
        """
        Render per-course performance rows as an .xlsx workbook in a ``custom_response``.

        The columns are ``_PERFORMANCE_EXPORT_FIELDS``, in the same order as
        the CSV export, so the header row is present even when there is no
        data. ``enrollment_distribution`` is currently unused.

        The binary workbook is returned twice:
        - ``content``: latin-1 decoded text, kept for backward compatibility.
          It only survives transport if the client re-encodes it as latin-1.
        - ``content_base64``: the standard, encoding-safe form.
        """
        try:
            df = pd.DataFrame(student_performance, columns=self._PERFORMANCE_EXPORT_FIELDS)

            output = io.BytesIO()
            with pd.ExcelWriter(output, engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name='Course Performance', index=False)

            excel_content = output.getvalue()
            output.close()

            filename = f"enrollment_performance_{period}_{datetime.now().strftime('%Y%m%d')}.xlsx"

            return custom_response(
                success=True,
                data={
                    'content': excel_content.decode('latin-1'),
                    'content_base64': base64.b64encode(excel_content).decode('ascii'),
                    'filename': filename,
                    'content_type': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                    'message': "Enrollment data exported successfully"
                }
            )

        except Exception as e:
            return custom_response(
                success=False,
                data={'error': f"Error creating Excel export: {str(e)}"},
                status_code=500
            )
