Commit 9fe7ff92 authored by Fardale's avatar Fardale

add communication verification for job submission and result

parent f9554ba1
open Lwt.Infix
open Serialization_t
let cmd cpu ram time iteration port script_file pass addr =
let cmd cpu ram time iteration port script_file pass addr () =
let script = CCIO.with_in script_file CCIO.read_all in
let submission = {script; time; iteration; cpu; ram; pass} in
Lwt_io.with_connection
(Lwt_unix.ADDR_INET (Unix.inet_addr_of_string addr, port))
(fun (_ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`JOB submission)) )
(fun (ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`JOB submission))
>>= fun () ->
Lwt_io.read ic
>>= fun json ->
match
CCResult.guard (fun () -> Serialization_j.answer_of_string json)
with
| Ok answer -> (
match answer with
| `Ok ->
Logs_lwt.info (fun m -> m "Computation successfully sent.")
| `Error s ->
Logs_lwt.err (fun m ->
m "Error during the reception of the computation: %s" s ) )
| Error e ->
Logs_lwt.err (fun m ->
m "Error during the reception of the answer: %s"
(Printexc.to_string e) ) )
......@@ -2,6 +2,9 @@ open Lwt.Infix
open Serialization_t
module SHashtbl = CCHashtbl.Make (CCString)
(* TODO: store currently running jobs and check from time
* to time that node are still alive *)
let nodes : (node * computation list) SHashtbl.t = SHashtbl.create 10
let jobs : computation CCDeque.t = CCDeque.create ()
......@@ -60,14 +63,14 @@ let rec launch_job () =
n.port )
>>= fun () ->
Lwt_io.with_connection sockaddr (fun (_ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`COMPUTATION j))
)
Lwt_io.write_line oc
(Serialization_j.string_of_query (`COMPUTATION j)) )
(* TODO: check return value *)
>>= Lwt.pause
>>= launch_job )
else Logs_lwt.debug (fun m -> m "No compuatiton")
else Logs_lwt.debug (fun m -> m "No computation")
let server_handler pass port sockaddr (ic, _oc) =
let server_handler pass port sockaddr (ic, oc) =
Lwt_io.read_line ic
>>= fun json ->
match CCResult.guard (fun () -> Serialization_j.query_of_string json) with
......@@ -75,10 +78,12 @@ let server_handler pass port sockaddr (ic, _oc) =
match query with
| `RESULT result ->
if pass = result.pass then
Logs_lwt.debug (fun m ->
m "Receive result: %s"
(Serialization_j.string_of_result
{result with stdout= "<stdout>"; stderr= "<stderr>"}) )
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive result: %s"
(Serialization_j.string_of_result
{result with stdout= "<stdout>"; stderr= "<stderr>"}) )
>>= fun () ->
Lwt.return
(Lwt.async (fun () ->
......@@ -103,16 +108,17 @@ let server_handler pass port sockaddr (ic, _oc) =
; end_job result.id sockaddr >>= Lwt.pause >>= launch_job ]
))
else
(*Lwt_io.write oc "Wrong password"
<&>*)
Logs_lwt.warn (fun m -> m "Wrong password: %s" result.pass)
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" result.pass)
| `JOB submission ->
if pass = submission.pass then (
(*Lwt_io.write oc "true"
<&>*)
Logs_lwt.debug (fun m ->
m "Receive submission: %s"
(Serialization_j.string_of_submission submission) )
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive submission: %s"
(Serialization_j.string_of_submission submission) )
>>= fun () ->
let empty = CCDeque.is_empty jobs in
for i = 0 to submission.iteration - 1 do
......@@ -133,15 +139,26 @@ let server_handler pass port sockaddr (ic, _oc) =
incr jobs_id ;
if empty then launch_job () else Lwt.return_unit )
else
(*Lwt_io.write oc "Wrong password"
<&>*)
Logs_lwt.warn (fun m -> m "Wrong password: %s" submission.pass)
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" submission.pass)
| _ ->
Logs_lwt.warn (fun m -> m "Receive a unwanted command") )
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Unwanted command"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Receive a unwanted command") )
| Result.Error e ->
Logs_lwt.err (fun m ->
m "Error during the reception of the computation: %s"
(Printexc.to_string e) )
Lwt_io.write_line oc
(Serialization_j.string_of_answer
(`Error
(Printf.sprintf
"Error during the reception of the computation: %s"
(Printexc.to_string e))))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.err (fun m ->
m "Error during the reception of the computation: %s"
(Printexc.to_string e) )
let stop_server resolver server _ = Lwt.wakeup_later resolver server
......
......@@ -41,6 +41,41 @@ let run_computation (computation : computation) =
; ret_code= process_status_to_ret_code ret_code
; pass= computation.pass } )
let rec send_result sockaddr result =
if%lwt
Lwt_io.with_connection sockaddr (fun (ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`RESULT result))
>>= fun () ->
Lwt_io.flush oc
>>= fun () ->
try%lwt
Lwt_io.read_line ic
>>= fun json ->
match
CCResult.guard (fun () -> Serialization_j.answer_of_string json)
with
| Ok answer -> (
match answer with
| `Ok ->
Logs_lwt.debug (fun m ->
m "Result %i,%i successfully sent." (fst result.id)
(snd result.id) )
>>= fun () -> Lwt.return false
| `Error s ->
Logs_lwt.err (fun m ->
m "Error during the reception of the result %i,%i: %s"
(fst result.id) (snd result.id) s )
>>= fun () -> Lwt.return true )
| Error e ->
Logs_lwt.err (fun m ->
m "Error during the reception of the answer: %s"
(Printexc.to_string e) )
>>= fun () -> Lwt.return true
with End_of_file ->
Logs_lwt.err (fun m -> m "Error during the read of the answer: EOF")
>>= fun () -> Lwt.return true )
then send_result sockaddr result
let handle_computation sockaddr computation () =
run_computation computation
>>= fun result ->
......@@ -51,14 +86,10 @@ let handle_computation sockaddr computation () =
| s ->
s
in
Lwt_io.with_connection sockaddr (fun (_ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`RESULT result))
>>= fun () -> Lwt_io.flush oc
)
(* TODO: rendre ça résistant au crash du serveur *)
>>= fun () ->
Logs_lwt.debug (fun m ->
m "End computation %i,%i" (fst computation.id) (snd computation.id) )
send_result sockaddr result
<&> Logs_lwt.debug (fun m ->
m "End computation %i,%i" (fst computation.id) (snd computation.id)
)
let stat oc = Lwt_io.write_line oc (Serialization_j.string_of_stat `OK)
......@@ -70,12 +101,18 @@ let server_handler pass sockaddr (ic, oc) =
match query with
| `COMPUTATION (computation : computation) ->
if pass = computation.pass then
Logs_lwt.debug (fun m ->
m "Receive computation: %s"
(Serialization_j.string_of_computation
{computation with env= []; script= "<script>"}) )
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive computation: %s"
(Serialization_j.string_of_computation
{computation with env= []; script= "<script>"}) )
>|= fun () -> Lwt.async (handle_computation sockaddr computation)
else Logs_lwt.warn (fun m -> m "Wrong password: %s" computation.pass)
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" computation.pass)
| `STAT ->
stat oc <&> Logs_lwt.debug (fun m -> m "Receive a stat command")
| _ ->
......
......@@ -106,7 +106,7 @@ let client_cmd =
( Term.(
const Lwt_main.run
$ ( const Client.cmd $ cpu $ ram $ time $ iteration $ port $ script $ pass
$ addr ))
$ addr $ setup_log ))
, Term.info "client" ~doc ~sdocs:Manpage.s_common_options ~exits ~man )
let default_cmd =
......
......@@ -13,6 +13,8 @@ type submission = {script: string; time: float option; iteration: int <ocaml def
type query = [COMPUTATION of computation | STAT | RESULT of result | JOB of submission]
type answer = [Ok | Error of string]
type stat = [OK]
type result =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment