admin.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from django.contrib import admin
  2. from django.utils.safestring import mark_safe
  3. from django.urls import path
  4. from .views import eval_download_results, config_download_results
  5. from .models import (
  6. QA, Dataset, EvalAnswer,
  7. LLMBackend, LLMModel,
  8. EvalConfig, RoleMessage, EvalSession,
  9. AnswerInterpreter
  10. )
  class QAAdmin(admin.ModelAdmin):
      """Admin configuration for question/answer (QA) items.

      The change form shows the editable QA fields plus the read-only ``hash``.
      """

      list_display = ('question', 'xid', 'dataset', 'category', 'correct_answer', 'target')
      list_filter = ('target', 'dataset', 'category')
      search_fields = ('question', 'correct_answer', 'category', 'extra_info')
      # When ``fieldsets`` is declared, Django renders only the fields listed in
      # it (read-only ones included), so ``hash`` must be listed to be visible.
      fieldsets = (
          (None, {'fields': (
              'dataset', 'question', 'category', 'extra_info', 'context',
              'options', 'correct_answer', 'correct_answer_idx', 'target', 'hash',
          )}),
      )
      readonly_fields = ('hash',)


  admin.site.register(QA, QAAdmin)
  @admin.register(Dataset)
  class DatasetAdmin(admin.ModelAdmin):
      """Admin for datasets: listed and searched by name/description, newest first."""

      fieldsets = ((None, {'fields': ('name', 'description')}),)
      ordering = ('-created_at',)
      search_fields = ('name', 'description')
      list_display = ('name', 'description')
  class EvalAnswerAdmin(admin.ModelAdmin):
      """Admin for individual answers recorded during evaluation sessions."""

      list_display = ('question', 'get_question_id', 'is_correct', 'llm_model', 'eval_session')
      list_filter = (
          'is_correct', 'llm_backend', 'llm_model', 'question__dataset',
          'eval_session', 'eval_session__config',
      )
      # ``question`` is a relation to QA; a bare relation in search_fields makes
      # every search raise FieldError, so search the related QA's text instead.
      search_fields = ('question__question', 'instruction', 'assistant_answer')
      ordering = ('-created_at', 'question')

      @admin.display(description='Question ID', ordering='question__id')
      def get_question_id(self, obj):
          """Return the id of the QA row this answer belongs to."""
          return obj.question.id


  admin.site.register(EvalAnswer, EvalAnswerAdmin)
  @admin.register(LLMBackend)
  class LLMBackendAdmin(admin.ModelAdmin):
      """Admin for LLM backends, listed and searched by name only."""

      search_fields = ('name',)
      list_display = ('name',)
  @admin.register(LLMModel)
  class LLMModelAdmin(admin.ModelAdmin):
      """Admin for LLM models: shows and filters by each model's backend."""

      search_fields = ('name',)
      list_filter = ('backend',)
      list_display = ('name', 'backend')
  @admin.register(RoleMessage)
  class RoleMessageAdmin(admin.ModelAdmin):
      """Admin for role messages, shown and filtered by role and eval config."""

      search_fields = ('role', 'content')
      list_filter = ('role', 'eval_config')
      list_display = ('role', 'eval_config')
  class RoleMessageInline(admin.TabularInline):
      """Tabular inline for RoleMessage rows, offering three blank extra forms."""

      extra = 3
      model = RoleMessage
  class EvalConfigAdmin(admin.ModelAdmin):
      """Admin for evaluation configs, with a per-row "Download Results" link."""

      list_display = ('name', 'dataset', 'created_at', 'link')
      search_fields = ('name', 'description')
      ordering = ('-created_at',)
      inlines = [RoleMessageInline]

      @admin.display(description='Link')
      def link(self, obj):
          """Render an anchor pointing at this config's results-download view.

          The URL is reversed from the name registered in ``get_urls`` rather than
          hard-coded, so it stays correct wherever the admin site is mounted.
          """
          from django.urls import reverse
          from django.utils.html import format_html

          url = reverse(
              f'{self.admin_site.name}:config_download_results',
              args=[obj.pk],
          )
          return format_html('<a href="{}">Download Results</a>', url)

      def get_urls(self):
          """Prepend the results-download route to the default admin URLs.

          The view is wrapped in ``admin_site.admin_view`` so it gets the same
          access check (active staff by default) and never-cache headers as the
          rest of the admin; unwrapped, anyone could download the results.
          """
          urls = super().get_urls()
          my_urls = [
              path(
                  '<int:config_id>/config_download_results/',
                  self.admin_site.admin_view(config_download_results),
                  name='config_download_results',
              ),
          ]
          # Custom routes go first so the default "<path:object_id>/" patterns
          # do not shadow them.
          return my_urls + urls


  admin.site.register(EvalConfig, EvalConfigAdmin)
  class EvalSessionAdmin(admin.ModelAdmin):
      """Admin for evaluation sessions: progress, accuracy and a results link."""

      list_display = ('id', 'name', 'is_active', 'config', 'llm_model', 'progress', 'accuracy', 'link')
      list_filter = ('is_active', 'config', 'llm_model')
      list_display_links = ["name"]
      # ``config`` and ``llm_model`` are relations; listing them bare makes every
      # search raise FieldError, so search their ``name`` columns instead.
      search_fields = ('name', 'config__name', 'llm_model__name')
      ordering = ('-created_at',)

      @admin.display(description='Link')
      def link(self, obj):
          """Render an anchor pointing at this session's results-download view.

          The URL is reversed from the name registered in ``get_urls`` rather than
          hard-coded, so it stays correct wherever the admin site is mounted.
          """
          from django.urls import reverse
          from django.utils.html import format_html

          url = reverse(
              f'{self.admin_site.name}:eval_download_results',
              args=[obj.pk],
          )
          return format_html('<a href="{}">Download Results</a>', url)

      def get_urls(self):
          """Prepend the results-download route to the default admin URLs.

          The view is wrapped in ``admin_site.admin_view`` so it gets the same
          access check (active staff by default) and never-cache headers as the
          rest of the admin; unwrapped, anyone could download the results.
          """
          urls = super().get_urls()
          my_urls = [
              path(
                  '<int:session_id>/eval_download_results/',
                  self.admin_site.admin_view(eval_download_results),
                  name='eval_download_results',
              ),
          ]
          # Custom routes go first so the default "<path:object_id>/" patterns
          # do not shadow them.
          return my_urls + urls

      def accuracy(self, obj):
          """Format the session's accuracy as a percentage (e.g. ``0.5`` -> ``50.00%``).

          Falls back to the admin's empty-value marker when accuracy is unset,
          instead of crashing the whole changelist with a TypeError.
          """
          if obj.accuracy is None:
              return self.get_empty_value_display()
          return "{:.2%}".format(obj.accuracy)

      def progress(self, obj):
          """Return ``"answered/total"`` for the session.

          ``answered`` is the number of EvalAnswer rows linked to the session;
          ``total`` is the number of QA rows in the config's dataset whose
          ``target`` equals the session's ``dataset_target``. This issues two
          COUNT queries per displayed row.
          """
          answered = obj.evalanswer_set.count()
          # Filtering on the FK column avoids loading the Dataset row itself.
          total = QA.objects.filter(
              dataset_id=obj.config.dataset_id,
              target=obj.dataset_target,
          ).count()
          return f"{answered}/{total}"


  admin.site.register(EvalSession, EvalSessionAdmin)
  @admin.register(AnswerInterpreter)
  class AnswerInterpreterAdmin(admin.ModelAdmin):
      """Admin for answer interpreters, newest first; searchable by own or model name."""

      ordering = ('-created_at',)
      search_fields = ('name', 'llm_model__name')
      list_display = ('name', 'llm_model')